{"version":"https://jsonfeed.org/version/1","title":"Henry Zhu","description":"Personal website for some random tidbits I work on\n","home_page_url":"https://maknee.github.io/","feed_url":"https://maknee.github.io/feed.json","items":[{"id":"https://maknee.github.io/blog/2026/NVIDIA-TileIR-Internals-from-CuTile-to-MLIR-LLVM-to-SASS","url":"https://maknee.github.io/blog/2026/NVIDIA-TileIR-Internals-from-CuTile-to-MLIR-LLVM-to-SASS/","title":"NVIDIA TileIR Internals: from CuTile to MLIR/LLVM to SASS","content_html":"\u003cp\u003eIn this post, we’ll dig deep into how TileIR works, from how it generates instructions to analyzing its different passes. We’ll trace how a Mixture-of-Experts (MoE) kernel written in CuTile gets compiled down through \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e → NVVM → LLVM → SASS.\u003c/p\u003e\n\n\u003cp\u003eHere’s what to expect:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003ca href=\"#what-is-cutile\"\u003e\u003cstrong\u003eWhat is CuTile?\u003c/strong\u003e\u003c/a\u003e — The tile-centric programming model\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#running-example-moe-kernel\"\u003e\u003cstrong\u003eRunning Example\u003c/strong\u003e\u003c/a\u003e — An MoE kernel we’ll trace through every stage\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#the-dialects\"\u003e\u003cstrong\u003eThe Dialects\u003c/strong\u003e\u003c/a\u003e — From \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile\u003c/code\u003e through \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e to NVVM/LLVM\u003c/li\u003e\n \u003cli\u003e\u003ca 
href=\"#the-passes\"\u003e\u003cstrong\u003eThe Passes\u003c/strong\u003e\u003c/a\u003e — TileIR passes: what they do and when they run\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cem\u003eBased on CUDA 13.1. Some details are undocumented and may change in future releases.\u003c/em\u003e\u003c/p\u003e\n\n\u003ch1 id=\"what-is-cutile\"\u003eWhat is CuTile?\u003c/h1\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/cutile.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eCuTile separates user responsibility (splitting work into blocks and tiles) from system responsibility (mapping to threads)\n \n (Image source: \u003ca href=\"https://youtu.be/_b4I4rKpsGA?t=406\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eGPU MODE\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003e\u003ca href=\"https://github.com/NVIDIA/cutile-python\"\u003eCuTile\u003c/a\u003e is NVIDIA’s new “tile-centric” programming model for modern NVIDIA GPUs. This abstraction is powerful: CuTile lets the programmer think in terms of tiles rather than threads, while the compiler handles the complexity of coordinating hundreds of threads across fragmented data. A single CuTile line \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ect.mma(a, b, acc)\u003c/code\u003e can be lowered into many tensor core instructions.\u003c/p\u003e\n\n\u003ch2 id=\"what-is-tileir\"\u003eWhat is TileIR?\u003c/h2\u003e\n\n\u003cp\u003eTileIR is NVIDIA’s MLIR-based compiler infrastructure that powers CuTile. It progressively lowers your high-level tensor operations through multiple MLIR dialects and NVIDIA-specific tools:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/pipeline_overview.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eTileIR compilation pipeline: Python → SASS\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe user-facing tool is \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileiras\u003c/code\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eLike \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eptxas\u003c/code\u003e but for TileIR. Yes, NVIDIA named it “tile-ir-as” (tile IR assembler).\u003c/span\u003e, which orchestrates this entire pipeline.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"running-example-moe-kernel\"\u003eRunning Example: MoE Kernel\u003c/h1\u003e\n\n\u003cp\u003eThroughout this post, we’ll trace this \u003cstrong\u003eMoE (Mixture-of-Experts) kernel\u003c/strong\u003e through every compilation stage. This is code from \u003ca href=\"https://github.com/NVIDIA/cutile-python/blob/main/samples/MoE.py\"\u003eNVIDIA’s cutile-python samples\u003c/a\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThere’s also a C++ API: \u003ca href=\"https://github.com/NVIDIA/cuda-tile\"\u003eNVIDIA/cuda-tile\u003c/a\u003e. 
Operations like \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ect.gather\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ect.mma\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.load_view_tko\u003c/code\u003e documented in \u003ca href=\"https://docs.nvidia.com/cuda/tile-ir/13.1/sections/operations.html\"\u003eTileIR docs\u003c/a\u003e.\u003c/span\u003e:\u003c/p\u003e\n\n\u003cdiv class=\"language-python highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"nd\"\u003e@ct.kernel\u003c/span\u003e\n\u003cspan class=\"k\"\u003edef\u003c/span\u003e \u003cspan class=\"nf\"\u003efused_moe_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"n\"\u003eA\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Input tokens, shape (batch, K)\n\u003c/span\u003e \u003cspan class=\"n\"\u003eB\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Expert weights, shape (num_experts, N, K)\n\u003c/span\u003e \u003cspan class=\"n\"\u003eC\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Output tensor, shape (num_tokens * topk, N)\n\u003c/span\u003e \u003cspan class=\"n\"\u003etopk_weights\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Router weights for each token-expert pair\n\u003c/span\u003e \u003cspan class=\"n\"\u003esorted_token_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Token indices sorted by expert assignment\n\u003c/span\u003e \u003cspan class=\"n\"\u003esorted_expert_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Expert index for each TILE_M\n\u003c/span\u003e \u003cspan 
class=\"n\"\u003enum_token_replicas\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"nb\"\u003eint\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003emul_routed_weight\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eConstBool\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003eTILE_M\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eConstInt\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eConstInt\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eConstInt\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n\u003cspan class=\"p\"\u003e):\u003c/span\u003e\n \u003cspan class=\"n\"\u003eM\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003esorted_token_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eshape\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n \u003cspan class=\"n\"\u003eN\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eB\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eshape\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n \u003cspan class=\"n\"\u003eK\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eB\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eshape\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eGROUP_SIZE_M\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e8\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebid_m\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebid_n\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nf\"\u003eswizzle_2d\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_M\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eGROUP_SIZE_M\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.get_tile_block_id\n\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Gather token indices for this block\n\u003c/span\u003e \u003cspan class=\"n\"\u003etoken_id_indices\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebid_m\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_M\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003earange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eTILE_M\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003edtype\u003c/span\u003e\u003cspan 
class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eint32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003etoken_ids\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003egather\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esorted_token_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etoken_id_indices\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.load_view_tko\n\u003c/span\u003e \u003cspan class=\"n\"\u003ea_row_indices\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003etoken_ids\u003c/span\u003e \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"n\"\u003enum_token_replicas\u003c/span\u003e\n \u003cspan class=\"n\"\u003eexpert_id\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003eload\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esorted_expert_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eindex\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003ebid_m\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eshape\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.load_ptr_tko\n\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Initialize accumulator\n\u003c/span\u003e \u003cspan 
class=\"n\"\u003eaccumulator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003efull\u003c/span\u003e\u003cspan class=\"p\"\u003e((\u003c/span\u003e\u003cspan class=\"n\"\u003eTILE_M\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan class=\"mf\"\u003e0.0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003edtype\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003efloat32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.constant\n\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003ek\u003c/span\u003e \u003cspan class=\"ow\"\u003ein\u003c/span\u003e \u003cspan class=\"nf\"\u003erange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003ecdiv\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eK\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e\u003cspan class=\"p\"\u003e)):\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.for\n\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Load A tile (gathered by token indices)\n\u003c/span\u003e \u003cspan class=\"n\"\u003ea_col_indices\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ek\u003c/span\u003e \u003cspan 
class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003earange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003edtype\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eint32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003egather\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eA\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ea_row_indices\u003c/span\u003e\u003cspan class=\"p\"\u003e[:,\u003c/span\u003e \u003cspan class=\"bp\"\u003eNone\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"n\"\u003ea_col_indices\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"bp\"\u003eNone\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e:]))\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.load_view_tko\n\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Load B tile (expert weights)\n\u003c/span\u003e \u003cspan class=\"n\"\u003eb\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003eload\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eB\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eexpert_id\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ek\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebid_n\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan class=\"n\"\u003eshape\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003eorder\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"mi\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)).\u003c/span\u003e\u003cspan class=\"nf\"\u003ereshape\u003c/span\u003e\u003cspan class=\"p\"\u003e((\u003c/span\u003e\u003cspan class=\"n\"\u003eTILE_K\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.load_ptr_tko\n\u003c/span\u003e\n \u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003emma\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ea\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eb\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.mmaf ← THE COMPUTE!\n\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"n\"\u003emul_routed_weight\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan class=\"n\"\u003emoe_weight\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003egather\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etopk_weights\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etoken_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003emoe_weight\u003c/span\u003e\u003cspan class=\"p\"\u003e[:,\u003c/span\u003e \u003cspan class=\"bp\"\u003eNone\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.mulf\n\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Scatter results back to output\n\u003c/span\u003e \u003cspan class=\"n\"\u003ec_col_indices\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebid_n\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003earange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eTILE_N\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003edtype\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eint32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003eastype\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eaccumulator\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eC\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003edtype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.ftof\n\u003c/span\u003e \u003cspan class=\"n\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003escatter\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eC\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etoken_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e[:,\u003c/span\u003e \u003cspan class=\"bp\"\u003eNone\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"n\"\u003ec_col_indices\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"bp\"\u003eNone\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e:]),\u003c/span\u003e \u003cspan 
class=\"n\"\u003eaccumulator\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"c1\"\u003e# → cuda_tile.store_ptr_tko\n\u003c/span\u003e\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003e\u003cstrong\u003eThe three key operations we’ll trace:\u003c/strong\u003e\u003c/p\u003e\n\n\u003ctable style=\"width: 100%; border-collapse: collapse; font-family: monospace; font-size: 0.9em;\"\u003e\n\u003cthead\u003e\n\u003ctr style=\"background: linear-gradient(135deg, #1a1a2e 0%, #16213e 100%);\"\u003e\n\u003cth style=\"padding: 12px; text-align: left; color: #76b900; border-bottom: 2px solid #76b900;\"\u003ePython\u003c/th\u003e\n\u003cth style=\"padding: 12px; text-align: left; color: #76b900; border-bottom: 2px solid #76b900;\"\u003ecuda_tile\u003c/th\u003e\n\u003cth style=\"padding: 12px; text-align: left; color: #76b900; border-bottom: 2px solid #76b900;\"\u003eWhat it does\u003c/th\u003e\n\u003c/tr\u003e\n\u003c/thead\u003e\n\u003ctbody\u003e\n\u003ctr style=\"background: rgba(118, 185, 0, 0.1);\"\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333;\"\u003ect.gather(A, indices)\u003c/td\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333;\"\u003eload_view_tko\u003c/td\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333; font-family: sans-serif;\"\u003eGather tokens by expert assignment (indirect load)\u003c/td\u003e\n\u003c/tr\u003e\n\u003ctr style=\"background: rgba(0, 150, 255, 0.1);\"\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333;\"\u003ect.load(B, ...)\u003c/td\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333;\"\u003eload_ptr_tko\u003c/td\u003e\n\u003ctd style=\"padding: 10px; border-bottom: 1px solid #333; font-family: sans-serif;\"\u003eLoad expert weights (direct load)\u003c/td\u003e\n\u003c/tr\u003e\n\u003ctr style=\"background: rgba(255, 100, 100, 0.1);\"\u003e\n\u003ctd style=\"padding: 
10px;\"\u003ect.mma(a, b, acc)\u003c/td\u003e\n\u003ctd style=\"padding: 10px;\"\u003emmaf\u003c/td\u003e\n\u003ctd style=\"padding: 10px; font-family: sans-serif;\"\u003eMatrix multiply-accumulate on tensor cores\u003c/td\u003e\n\u003c/tr\u003e\n\u003c/tbody\u003e\n\u003c/table\u003e\n\n\u003cp\u003eWatch how these are transformed through \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e, and finally into SASS instructions.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"compiling-with-tileiras\"\u003eCompiling with tileiras\u003c/h1\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileiras\u003c/code\u003e command-line tool is the ahead-of-time compiler that transforms \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e.cutile\u003c/code\u003e bytecode into GPU binaries.\u003c/p\u003e\n\n\u003cdiv class=\"language-bash highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003etileiras \u003cspan class=\"nt\"\u003e--gpu-name\u003c/span\u003e sm_120 MoE.cutile \u003cspan class=\"nt\"\u003e-o\u003c/span\u003e moe.cubin\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch2 id=\"undocumented-environment-variables\"\u003eUndocumented Environment Variables\u003c/h2\u003e\n\n\u003cp\u003eThese TileIR-specific environment variables affect compilation:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"env-vars-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"env-vars-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Variable\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Description\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTILEIR_ALWAYS_SWIZZLE\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eForce swizzle mode\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003eTILEIR_PREFER_TMA_FOR_LOAD_STORE\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePrefer TMA for all load/store operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTILEIR_DELAY_TMA_STORE_WAIT\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"env-vars-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDelay TMA store wait (optimization for overlapping compute)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003ch2 id=\"interesting-undocumented-cli-options\"\u003eInteresting undocumented CLI options\u003c/h2\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e--print-before-all\u003c/code\u003e flag dumps LLVM IR before each compilation pass.\u003c/p\u003e\n\n\u003cdiv class=\"language-bash highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"nv\"\u003e$ \u003c/span\u003etileiras \u003cspan class=\"nt\"\u003e--print-before-all\u003c/span\u003e \u003cspan class=\"nt\"\u003e--gpu-name\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003esm_120 MoE.cutile \u003cspan class=\"nt\"\u003e-o\u003c/span\u003e moe.cubin 2\u0026gt;\u0026amp;1\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cdiv 
class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"p\"\u003e***\u003c/span\u003e \u003cspan class=\"err\"\u003eIR\u003c/span\u003e \u003cspan class=\"err\"\u003eDump\u003c/span\u003e \u003cspan class=\"err\"\u003eBefore\u003c/span\u003e \u003cspan class=\"err\"\u003eAdd\u003c/span\u003e \u003cspan class=\"err\"\u003e__emutls_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"err\"\u003evt\u003c/span\u003e\u003cspan class=\"p\"\u003e].\u003c/span\u003e \u003cspan class=\"err\"\u003evariables\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003eemultated\u003c/span\u003e \u003cspan class=\"err\"\u003eTLS\u003c/span\u003e \u003cspan class=\"err\"\u003emodel\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003elower-emutls\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e***\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; ModuleID = 'LLVMDialectModule'\u003c/span\u003e\n\u003cspan class=\"k\"\u003esource_filename\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"LLVMDialectModule\"\u003c/span\u003e\n\u003cspan class=\"k\"\u003etarget\u003c/span\u003e \u003cspan class=\"k\"\u003edatalayout\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128\"\u003c/span\u003e\n\n\u003cspan class=\"vg\"\u003e@__CUDA_TILEIR_FUNC_NAME_0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003einternal\u003c/span\u003e \u003cspan class=\"k\"\u003econstant\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"m\"\u003e17\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan 
class=\"kt\"\u003ei8\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"s\"\u003ec\"fused_moe_kernel\\00\"\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cdetails\u003e\n \u003csummary\u003e\u003cstrong\u003eAll LLVM passes dumped (27 unique passes)\u003c/strong\u003e\u003c/summary\u003e\n\n \u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e*** IR Dump Before Add __emutls_[vt]. variables for emultated TLS model (lower-emutls) ***\n*** IR Dump Before Canonicalize natural loops (loop-simplify) ***\n*** IR Dump Before CodeGen Prepare (codegenprepare) ***\n*** IR Dump Before Constant Hoisting (consthoist) ***\n*** IR Dump Before Exception handling preparation (dwarf-eh-prepare) ***\n*** IR Dump Before Expand Atomic instructions (atomic-expand) ***\n*** IR Dump Before Expand fp (expand-fp) ***\n*** IR Dump Before Expand indirectbr instructions (indirectbr-expand) ***\n*** IR Dump Before Expand large div/rem (expand-large-div-rem) ***\n*** IR Dump Before Expand memcmp() to load/stores (expand-memcmp) ***\n*** IR Dump Before Expand reduction intrinsics (expand-reductions) ***\n*** IR Dump Before Instrument function entry/exit with calls to e.g. 
mcount() (post-inline-ee-instrument) ***\n*** IR Dump Before Interleaved Access Pass (interleaved-access) ***\n*** IR Dump Before Lower AMX intrinsics (lower-amx-intrinsics) ***\n*** IR Dump Before Lower AMX type for load/store (lower-amx-type) ***\n*** IR Dump Before Lower Garbage Collection Instructions (gc-lowering) ***\n*** IR Dump Before Merge contiguous icmps into a memcmp (mergeicmps) ***\n*** IR Dump Before ObjC ARC contraction (objc-arc-contract) ***\n*** IR Dump Before Partially inline calls to library functions (partially-inline-libcalls) ***\n*** IR Dump Before Pre-ISel Intrinsic Lowering (pre-isel-intrinsic-lowering) ***\n*** IR Dump Before Prepare callbr (callbrprepare) ***\n*** IR Dump Before Remove unreachable blocks from the CFG (unreachableblockelim) ***\n*** IR Dump Before Replace intrinsics with calls to vector library (replace-with-veclib) ***\n*** IR Dump Before Safe Stack instrumentation pass (safe-stack) ***\n*** IR Dump Before Scalarize Masked Memory Intrinsics (scalarize-masked-mem-intrin) ***\n*** IR Dump Before Shadow Stack GC Lowering (shadow-stack-gc-lowering) ***\n*** IR Dump Before X86 Partial Reduction (x86-partial-reduction) ***\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"pipeline-overview\"\u003ePipeline Overview\u003c/h1\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/pipeline_overview.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eTileIR compilation pipeline: Python → SASS\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c!-- Excalidraw diagram: Pipeline Flow - Python → cuda_tile → nv_tileaa → nv_tileas → NVVM → LLVM → PTX → SASS --\u003e\n\n\u003cp\u003eTileIR takes your CuTile Python code through a series of progressive lowerings:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" 
href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cdiv id=\"pipeline-stages-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"pipeline-stages-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Stage\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Format\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Description\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePython\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCuTile API\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHigh-level tensor operations (make_tensor_view; mmaf)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e.cutile\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eBytecode\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSerialized representation of the kernel\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ecuda_tile\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMLIR Dialect\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003eHigh-level tensor ops; architecture-independent\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileaa\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMLIR Dialect\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTile-level ops; explicit memory references\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileas\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMLIR Dialect\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eScheduled ops; async pipelines\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"pipeline-stages-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLLVM/NVVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLLVM IR\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStandard LLVM with NVIDIA intrinsics\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePTX\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAssembly\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eVirtual GPU assembly\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSASS\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"pipeline-stages-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMachine Code\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"pipeline-stages-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNative GPU instructions (sm_120)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eEach stage removes abstraction and adds architecture-specific detail. By the time we reach SASS, every memory access pattern, tensor core instruction, and synchronization barrier is explicit.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"the-dialects\"\u003eThe Dialects\u003c/h1\u003e\n\n\u003cp\u003eTileIR uses three main MLIR dialects to represent computations at different abstraction levels. Let’s trace our MoE kernel through each one:\u003c/p\u003e\n\n\u003c!-- Excalidraw diagram: MoE Operation Mapping - shows gather/load/mma traced through each dialect --\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"moe-ops-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"moe-ops-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Python\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n cuda_tile\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n nv_tileaa\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n nv_tileas\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n SASS\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.gather(A\u0026#44; idx)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eload_view_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileaa.load_view\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas.utcpglobalmem\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eUTCPMULTI / LDG\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.load(B\u0026#44; ...)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eload_ptr_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileaa.load_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas.tcgen05_ld\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row1-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTCGEN05.LD.S\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.mma(a\u0026#44; b\u0026#44; c)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emmaf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileaa.mmaf_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas.tcgen05_mma\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"moe-ops-table-row2-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTCGEN05.MMA\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003ch2 id=\"cuda_tile-high-level-tensor-operations\"\u003ecuda_tile: High-Level Tensor Operations\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/cuda_tile_dialect.svg\" width=\"100%\" alt=\"\" /\u003e\n \n 
\n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003ecuda_tile dialect operations\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile\u003c/code\u003e dialect is closest to your Python code. Operations work on abstract tensor views without worrying about memory layout or hardware details.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eKey operations:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_tensor_view\u003c/code\u003e - Create a view into a tensor with shape and strides\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget_tile_block_id\u003c/code\u003e - Get the current thread block’s position in the grid\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eload_view_tko\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003estore_view_tko\u003c/code\u003e - Load/store tiles with token-based ordering\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003emmaf\u003c/code\u003e - Matrix multiply-accumulate (targets tensor cores)\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003efor\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003econtinue\u003c/code\u003e - Loop constructs for K-dimension iteration\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch3 id=\"moe-in-cuda_tile\"\u003eMoE in cuda_tile\u003c/h3\u003e\n\n\u003cp\u003eRecall our \u003ca href=\"#running-example-moe-kernel\"\u003eMoE kernel above\u003c/a\u003e. 
Here’s how the key operations map to \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile\u003c/code\u003e IR:\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003ePython → cuda_tile mapping:\u003c/strong\u003e\u003c/p\u003e\n\n\u003cdiv id=\"python-ir-mapping-table-wrapper\" class=\"px-4 rounded-lg 
__basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"python-ir-mapping-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Python (CuTile)\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n cuda_tile IR\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Purpose\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.gather()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eload_view_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGather elements by indices\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.load()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row1-col1\" class=\"px-4 py-2 
whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eload_ptr_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLoad contiguous tile from memory\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.mma()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emmaf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMatrix multiply-accumulate (tensor cores)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.scatter()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003estore_ptr_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"python-ir-mapping-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eScatter elements to output\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.full()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econstant\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eInitialize accumulator\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003efor k in range()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003efor/continue\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eK-dimension iteration loop\u003c/span\u003e\n \u003c/td\u003e\n \n 
\n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.astype()\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eftof\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"python-ir-mapping-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eType conversion (F32 → output dtype)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see cuda_tile IR from MoE kernel key sections\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ecuda_tile\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003eMoE\u003c/span\u003e \u003cspan class=\"err\"\u003ekernel\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTILE_M\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%2\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTILE_N\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%3\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTILE_K\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%4\u003c/span\u003e 
\u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%5\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%10\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.make_tensor_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%4\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%5\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%6\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%7\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%8\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%9\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%11\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_tensor_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%12\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_token\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%21\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%22\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.get_tile_block_id\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%23\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%4\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eM\u003c/span\u003e \u003cspan class=\"err\"\u003e/\u003c/span\u003e \u003cspan class=\"err\"\u003eTILE_M\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%24\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%23\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%25\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%24\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.remi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%25\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eexpert\u003c/span\u003e \u003cspan class=\"err\"\u003erouting\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.select\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%25\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%40\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.iota\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%24\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%42\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%43\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%42\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%40\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%44\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.offset\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%42\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%43\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%51\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%44\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%42\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"k\"\u003eload\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%52\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_partition_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%10\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%53\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%54\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_view_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%52\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%43\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003egather\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.for\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%23\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg4\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"nl\"\u003eK-loop\n :\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%61\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%iter\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%62\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%61\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%63\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%64\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%62\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%42\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%65\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%66\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_view_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%52\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%62\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%67\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.mmaf\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%63\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%65\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%acc\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emma\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"s\"\u003e\"cuda_tile.continue\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%67\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%70\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.ftof\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eastype\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%71\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.store_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%44\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%70\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003escatter\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"cuda_tile.return\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"nv_tileaa\"\u003env_tileaa\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/nv_tileaa_dialect.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003env_tileaa dialect operations\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e dialect lowers tensor views to concrete memory references. It is also the first stage where explicit memory operations appear: pointer arithmetic, masked loads and stores, and tiled loads that carry memory tokens.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eKey changes from cuda_tile:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_tensor_view\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_memref\u003c/code\u003e (explicit memory references)\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget_tile_block_id\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget_program_id\u003c/code\u003e (program-centric naming)\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003emmaf\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003edot\u003c/code\u003e (more explicit accumulation)\u003c/li\u003e\n \u003cli\u003eExplicit \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etiled_load\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etiled_store\u003c/code\u003e with memory tokens\u003c/li\u003e\n \u003cli\u003eNew ops: \u003ccode 
class=\"language-plaintext highlighter-rouge\"\u003esplat\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ebroadcast\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eaddptr\u003c/code\u003e for memory address calculations\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see nv_tileaa IR from MoE kernel key sections\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003eMoE\u003c/span\u003e \u003cspan class=\"err\"\u003ekernel\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTile-level\u003c/span\u003e \u003cspan class=\"err\"\u003eops\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003earchitecture-independent\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"s\"\u003e\"nv_tileaa.func\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ekernel_spec\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eInput\u003c/span\u003e \u003cspan 
class=\"err\"\u003evalidation\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%2\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%3\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%2\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"nl\"\u003eSplat:\u003c/span\u003e \u003cspan class=\"err\"\u003escalar\u003c/span\u003e \u003cspan class=\"err\"\u003e→\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003ebroadcasting\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%10\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%11\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%2\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eMemory\u003c/span\u003e \u003cspan class=\"err\"\u003ereference\u003c/span\u003e \u003cspan class=\"err\"\u003ecreation\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003elowered\u003c/span\u003e \u003cspan class=\"k\"\u003efrom\u003c/span\u003e \u003cspan class=\"err\"\u003emake_tensor_view\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_memref\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%4\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%5\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%6\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%21\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_memref\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%2\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%22\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.create_mem_token\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eProgram\u003c/span\u003e \u003cspan class=\"err\"\u003eindexing\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.get_program_id\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_range\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c128\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%33\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ePointer\u003c/span\u003e \u003cspan class=\"err\"\u003earithmetic\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%40\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.addptr\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%40\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%33\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eMasked\u003c/span\u003e \u003cspan class=\"err\"\u003eloads\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%51\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mask\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%22\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTiled\u003c/span\u003e \u003cspan class=\"err\"\u003ememory\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.block_tile\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%61\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%62\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%63\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%61\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%22\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%64\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%62\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eShape\u003c/span\u003e \u003cspan class=\"err\"\u003emanipulation\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%70\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%33\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%71\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%70\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eDOT\u003c/span\u003e \u003cspan class=\"err\"\u003eOPERATION\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003elowered\u003c/span\u003e \u003cspan class=\"k\"\u003efrom\u003c/span\u003e \u003cspan class=\"err\"\u003ecuda_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emmaf\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%80\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%64\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%acc\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eOutput\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%90\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.fp_to_fp\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%80\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%91\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.store\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%90\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mask\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%22\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileaa.return\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003cp\u003e\u003cstrong\u003eKey transformations from cuda_tile → nv_tileaa:\u003c/strong\u003e\u003c/p\u003e\n\n \u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n \u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n \u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n \u003cdiv id=\"dialect-comparison-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"dialect-comparison-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n cuda_tile\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n nv_tileaa\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Change\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emake_tensor_view\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emake_memref\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row0-col2\" class=\"px-4 py-2 
whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAbstract view → concrete memory ref\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eget_tile_block_id\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eget_program_id\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTile-centric → program-centric naming\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emmaf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003edot\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHigh-level MMA → explicit dot product\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n 
\n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eload_view_tko\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etiled_load + view\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDecomposed into separate ops\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.view types\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etensor\u0026lt;...\u0026gt;\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAbstract → explicit tensor shapes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ect.token\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eaa.btile; aa.mtoken\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"dialect-comparison-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMemory tokens become more specific\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n \u003cp\u003e\u003cstrong\u003ePass #12 observation:\u003c/strong\u003e The 32 \u003ccode class=\"language-plaintext highlighter-rouge\"\u003efp_to_fp\u003c/code\u003e operations suggest that this MoE kernel produces 32 output tiles, each needing a precision conversion from the F32 accumulator to the output dtype.\u003c/p\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"nv_tileas\"\u003env_tileas\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/nv_tileas_tcgen05.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003env_tileas dialect with tcgen05 operations\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e dialect is where architecture-specific code generation happens. Abstract tile operations are mapped onto concrete hardware mechanisms such as TMA descriptors, asynchronous software pipelines, and tcgen05 tensor core instructions.\u003c/p\u003e\n\n\u003cp\u003eThis dialect introduces:\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eAsync Pipeline Operations:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode 
class=\"language-plaintext highlighter-rouge\"\u003easync.pipeline.create_pipeline\u003c/code\u003e - Create a software pipeline for 
overlapping compute/memory\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eproducer_acquire\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eproducer_commit\u003c/code\u003e - Producer side: acquire a free pipeline stage, then commit it once its data has been written\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003econsumer_wait\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003econsumer_release\u003c/code\u003e - Consumer side: wait until a stage is filled, then release it back to the producer\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eTensor Memory Operations:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etcgen05.alloc\u003c/code\u003e - Allocate Tensor Memory (TMEM), the dedicated on-chip memory used by tcgen05 instructions\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etmem_load\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etmem_store\u003c/code\u003e - Access tensor memory\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eTensor Core Operations:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etcgen05.mma\u003c/code\u003e - Matrix Multiply-Accumulate\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eblock_scaled_mma\u003c/code\u003e - Block-scaled MMA for low-precision inputs that carry per-block scale factors (e.g., MX formats)\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003emma.fence\u003c/code\u003e - Memory fence for MMA operations\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; 
font-weight: bold;\"\u003eExpand to see key sections of the nv_tileas IR from the MoE kernel\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv 
class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003env_tileas\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003eMoE\u003c/span\u003e \u003cspan class=\"err\"\u003ekernel\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTile-level\u003c/span\u003e \u003cspan class=\"err\"\u003eScheduled\u003c/span\u003e \u003cspan class=\"err\"\u003eAssembly\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eLayout\u003c/span\u003e \u003cspan class=\"err\"\u003econversion\u003c/span\u003e \u003cspan class=\"k\"\u003eand\u003c/span\u003e \u003cspan class=\"err\"\u003eview\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%2\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%ptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mask\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan 
class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%4\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%btile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%idx\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%5\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eConvert\u003c/span\u003e \u003cspan class=\"err\"\u003elayout\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e \u003cspan class=\"err\"\u003ecores\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%10\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%bcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%11\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%5\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%12\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan 
class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eDOT\u003c/span\u003e \u003cspan class=\"err\"\u003ewith\u003c/span\u003e \u003cspan class=\"err\"\u003einput\u003c/span\u003e \u003cspan class=\"err\"\u003eallowances\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%20\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%10\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%11\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%12\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTMA\u003c/span\u003e \u003cspan class=\"err\"\u003edescriptor\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%25\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"nv_tileas.make_tiled_tma_desc\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%memref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003etmaIdx\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!tma.desc\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eASYNC\u003c/span\u003e \u003cspan class=\"err\"\u003ePIPELINE\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eproducer-consumer\u003c/span\u003e \u003cspan class=\"err\"\u003emodel\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ePipeline\u003c/span\u003e \u003cspan class=\"k\"\u003eand\u003c/span\u003e \u003cspan class=\"err\"\u003eiterator\u003c/span\u003e \u003cspan class=\"err\"\u003ecreation\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_pipeline\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_pipeline\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%33\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eAgent\u003c/span\u003e \u003cspan class=\"k\"\u003eswitch\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003eproducer/consumer\u003c/span\u003e \u003cspan class=\"err\"\u003eroles\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.agent_switch\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%32\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%33\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTensor\u003c/span\u003e \u003cspan class=\"err\"\u003eallocation\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e\u003cspan class=\"err\"\u003e-buffering\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%40\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.alloc_tensor\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e128\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"m\"\u003e64\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"err\"\u003ebf16\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%41\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.alloc_tensor\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e64\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"m\"\u003e128\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"err\"\u003ebf16\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eSlice\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.extract_slice\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%40\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%51\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"nv_tileas.insert_slice\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%data\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%40\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c64\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"nl\"\u003ePRODUCER:\u003c/span\u003e \u003cspan class=\"k\"\u003eacquire\u003c/span\u003e \u003cspan class=\"err\"\u003e→\u003c/span\u003e \u003cspan class=\"err\"\u003ewrite\u003c/span\u003e \u003cspan class=\"err\"\u003e→\u003c/span\u003e \u003cspan class=\"err\"\u003ecommit\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_acquire\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e 
\u003cspan class=\"nv\"\u003e%32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%61\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_write\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%60\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%62\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.load\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%51\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%ptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mask\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c16\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!async\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.yield\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%62\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!async\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_commit\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%30\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%61\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"nl\"\u003eCONSUMER:\u003c/span\u003e \u003cspan class=\"err\"\u003ewait\u003c/span\u003e \u003cspan class=\"err\"\u003e→\u003c/span\u003e \u003cspan class=\"err\"\u003eread\u003c/span\u003e \u003cspan class=\"err\"\u003e→\u003c/span\u003e \u003cspan class=\"k\"\u003erelease\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%70\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_wait\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%33\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!iter\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%71\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%72\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_read\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%70\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%73\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.copy\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%buf\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e 
\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.yield\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%73\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_release\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%31\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%71\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!stage\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eMatrix\u003c/span\u003e \u003cspan class=\"err\"\u003emultiply\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e100\u003c/span\u003e\u003cspan class=\"err\"\u003e+\u003c/span\u003e \u003cspan class=\"err\"\u003eops\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003etiled\u003c/span\u003e 
\u003cspan class=\"err\"\u003eGEMM\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%80\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%72\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%acc\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%81\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%50\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%72\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%80\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTMA\u003c/span\u003e \u003cspan class=\"k\"\u003eload\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%90\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.tiled_tma_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%btile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%buf\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%25\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%idx\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c64\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"nv\"\u003e!tma.desc\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e!async\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eOutput\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%100\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.insert_slice\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%result\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%41\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%101\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%100\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%102\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%101\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"nvvm--llvm\"\u003eNVVM + LLVM\u003c/h2\u003e\n\n\u003cp\u003eAfter \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e, the compiler lowers to NVVM (NVIDIA’s LLVM dialect) and then to standard LLVM IR.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eKey NVVM intrinsics:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.mma.sync.*\u003c/code\u003e - Warp-level tensor core matrix multiply-accumulate\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.ldmatrix.*\u003c/code\u003e - Load matrix fragments from shared memory\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.cp.async.*\u003c/code\u003e - Asynchronous global-to-shared memory copy\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.bar.warp.sync\u003c/code\u003e - Warp-level synchronization\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.tcgen05.*\u003c/code\u003e - Fifth-generation (Blackwell) tensor core intrinsics\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see NVVM/LLVM IR key sections\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e; Thread ID and warp-level operations\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%233\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan
class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"err\"\u003erange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"m\"\u003e1024\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.read.ptx.sreg.tid.x\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%234\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003eicmp\u003c/span\u003e \u003cspan class=\"k\"\u003eeq\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%233\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%235\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003eashr\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%233\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"m\"\u003e5\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%236\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.shfl.sync.idx.i32\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e-1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%235\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan 
class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e31\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%237\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.elect.sync\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e-1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Mbarrier initialization (async pipeline synchronization)\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mbarrier.init.shared\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003egetelementptr\u003c/span\u003e \u003cspan class=\"k\"\u003einbounds\u003c/span\u003e \u003cspan class=\"k\"\u003enuw\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei8\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"vg\"\u003e@global_smem\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei64\u003c/span\u003e \u003cspan class=\"m\"\u003e82000\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%241\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mbarrier.init.shared\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003egetelementptr\u003c/span\u003e \u003cspan class=\"k\"\u003einbounds\u003c/span\u003e \u003cspan class=\"k\"\u003enuw\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei8\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"vg\"\u003e@global_smem\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei64\u003c/span\u003e \u003cspan class=\"m\"\u003e82008\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%241\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Cluster-scope fence for mbarrier init, then CTA-wide barrier\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan
class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"k\"\u003easm\u003c/span\u003e \u003cspan class=\"k\"\u003esideeffect\u003c/span\u003e \u003cspan class=\"s\"\u003e\"fence.mbarrier_init.release.cluster;\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"s\"\u003e\"n\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.barrier.cta.sync.aligned.all\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Async copy from global to shared memory (cp.async)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%1478\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003eselect\u003c/span\u003e \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1459\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.cp.async.cg.shared.global.16.s\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1477\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1451\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1478\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.cp.async.cg.shared.global.16.s\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1485\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1452\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1486\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Signal mbarrier arrival after async copy\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.cp.async.mbarrier.arrive.noinc.shared\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1535\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; TCGEN05 tensor core intrinsics\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; Allocate tensor memory\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%tmem\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.tcgen05.alloc\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e65536\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Load data into tensor memory\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.tcgen05.ld\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%tmem\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%smem_ptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%size\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Execute TCGEN05 MMA (128x256x64 
tile)\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.tcgen05.mma\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%tmem_a\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%tmem_b\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%tmem_c\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Fence and wait for tensor core completion\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.tcgen05.fence\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.tcgen05.wait\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"sass\"\u003eSASS\u003c/h2\u003e\n\n\u003cp\u003eThe final output is SASS, the native machine code that runs on the GPU.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eKey SASS instructions:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eHMMA.16816.F32.BF16\u003c/code\u003e - BF16 matrix multiply-accumulate (m16n8k16) with FP32 accumulation\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eTCGEN05.MMA\u003c/code\u003e - Fifth-generation (tcgen05) tensor core MMA\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eTCGEN05.LD.S\u003c/code\u003e - Tensor memory load\u003c/li\u003e\n 
\u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eUTCPMULTI\u003c/code\u003e / \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eLDG\u003c/code\u003e - Global memory loads\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eSYNCS.EXCH\u003c/code\u003e - Async synchronization exchange\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eFENCE.VIEW.ASYNC.S\u003c/code\u003e - Async memory fence\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see SASS key sections\u003c/summary\u003e\n\n \u003cdiv class=\"language-nasm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e; SASS - MoE kernel (fused_moe_kernel)\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; Target: sm_120a\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Thread ID and CTA setup\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0020\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eS2R\u003c/span\u003e \u003cspan class=\"nv\"\u003eR0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eSR_TID.X\u003c/span\u003e \u003cspan class=\"c1\"\u003e; ; Get thread ID\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0060\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eS2UR\u003c/span\u003e \u003cspan class=\"nv\"\u003eUR8\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eSR_CgaCtaId\u003c/span\u003e \u003cspan class=\"c1\"\u003e; ; Get CTA ID (uniform reg)\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Async fence and 
mbarrier initialization (0x14050 = 82000 and 0x14058 = 82008, the mbarrier.init offsets in the LLVM IR)\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0110\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eFENCE.VIEW.ASYNC.S\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0120\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eSYNCS.EXCH.64\u003c/span\u003e \u003cspan class=\"nv\"\u003eURZ\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eUR8\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mh\"\u003e0x14050\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"nv\"\u003eUR4\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0130\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eSYNCS.EXCH.64\u003c/span\u003e \u003cspan class=\"nv\"\u003eURZ\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eUR8\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mh\"\u003e0x14058\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"nv\"\u003eUR4\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e0140\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nf\"\u003eSYNCS.EXCH.64\u003c/span\u003e \u003cspan class=\"nv\"\u003eURZ\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan 
class=\"nv\"\u003eUR8\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mh\"\u003e0x14060\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"nv\"\u003eUR6\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; ... (data loading, address calculation) ...\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Tensor core HMMA - 16x8x16 BF16→F32 matrix multiply\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; R156 = A matrix fragment (reused across 7 HMMAs)\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; R124,R120,R116,R112,R108,R104,R100 = B matrix fragments\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e; R200,R204,R64,R60,R56,R52,R48 = accumulator tiles\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea00\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR200\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR124\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR200\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea10\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR204\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR120\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003eR204\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea20\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR64\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR116\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR64\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea30\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR60\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR112\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR60\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea40\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR56\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR108\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR56\u003c/span\u003e \u003cspan 
class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea50\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR52\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR104\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR52\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea60\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR48\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR156\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR100\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR48\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Second A fragment (R148) with different B fragments\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea70\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR200\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR148\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR126\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003eR200\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea80\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR204\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR148\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR122\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR204\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/*\u003c/span\u003e\u003cspan class=\"err\"\u003e4\u003c/span\u003e\u003cspan class=\"nf\"\u003ea90\u003c/span\u003e\u003cspan class=\"o\"\u003e*/\u003c/span\u003e \u003cspan class=\"nv\"\u003eHMMA.16816.F32.BF16\u003c/span\u003e \u003cspan class=\"nv\"\u003eR64\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR148\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR118\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eR64\u003c/span\u003e \u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"the-tileir-passes\"\u003eThe TileIR passes\u003c/h1\u003e\n\n\u003cp\u003eTileIR runs multiple passes to transform your code. 
The passes are grouped by the scope they operate on:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/pass_flow.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eTileIR pass pipeline\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2026-01-29/pass_glossary.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eDetailed pass pipeline: cuda_tile.entry → nv_tileaa.func (×12) → builtin.module → gpu.module\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003chr /\u003e\n\n\u003ch3 id=\"pass-1-cuda_tileentry\"\u003ePass 1: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.entry\u003c/code\u003e\u003c/h3\u003e\n\n\u003cp\u003eEntry point canonicalization—validates kernel structure, emits compile-time constants for tile sizes/strides, propagates input constraints via \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eassume\u003c/code\u003e operations, creates tensor views, and establishes memory ordering via \u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_token\u003c/code\u003e.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch3 id=\"pass-2-nv_tileaafunc-12-iterations\"\u003ePass 2: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.func\u003c/code\u003e (×12 iterations)\u003c/h3\u003e\n\n\u003cp\u003eIterative lowering from cuda_tile to nv_tileaa. 
The first iteration converts \u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_tensor_view\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003emake_memref\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget_tile_block_id\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget_program_id\u003c/code\u003e, and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003emmaf\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003edot\u003c/code\u003e, and decomposes \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eload_view_tko\u003c/code\u003e into \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eblock_tile\u003c/code\u003e + \u003ccode class=\"language-plaintext highlighter-rouge\"\u003etiled_load\u003c/code\u003e + \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eview\u003c/code\u003e. Subsequent iterations refine and optimize the resulting IR. 
The final iteration emits precision conversions (\u003ccode class=\"language-plaintext highlighter-rouge\"\u003efp_to_fp\u003c/code\u003e), adds kernel metadata, and prepares for async pipeline lowering.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch3 id=\"pass-3-builtinmodule\"\u003ePass 3: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ebuiltin.module\u003c/code\u003e\u003c/h3\u003e\n\n\u003cp\u003eModule-level transforms and nv_tileas emission—creates async pipeline operations, software pipelines that overlap compute with memory transfers, producer-consumer synchronization, TMA descriptors, and double buffering.\u003c/p\u003e\n\n\u003chr /\u003e\n\n\u003ch3 id=\"pass-4-gpumodule\"\u003ePass 4: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003egpu.module\u003c/code\u003e\u003c/h3\u003e\n\n\u003cp\u003eFinal lowering to NVVM/LLVM—converts \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas.dot\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.mma.sync\u003c/code\u003e, lowers async ops to barrier/fence instructions, converts memory ops to NVVM intrinsics (\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eldmatrix\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecp.async\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003embarrier.*\u003c/code\u003e), and emits address space annotations.\u003c/p\u003e\n\n\u003ch2 id=\"complete-pass-catalog\"\u003eComplete Pass Catalog\u003c/h2\u003e\n\n\u003cp\u003eBelow is a catalog of passes that run within the TileIR pipeline.\u003c/p\u003e\n\n\u003ch3 id=\"conversion-passes\"\u003eConversion Passes\u003c/h3\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"conversion-passes-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"conversion-passes-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Pass Name\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Source\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Target\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Description\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-cudatile-to-tileaa\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003ecuda_tile\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileaa\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eFrontend: CuTile DSL to TileAA abstract assembly\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-tileaa-to-tileas\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileaa\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileas\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMiddle-end: Abstract to scheduled assembly\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-nv-tileas-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tileas\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eBackend: TileAS to LLVM IR\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-nv-tile-func-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tile\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row3-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003eConvert tile function ops to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-gpu-to-nvvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003egpu\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003envvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row4-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGPU dialect to NVVM intrinsics\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-scf-to-cf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003escf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: 
rgb(75, 85, 99) !important;\"\u003ecf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row5-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStructured control flow to basic blocks\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env-tile-ir-convert-target-to-nvvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003env_tile\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003envvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row6-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTarget-specific ops to NVVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-pipeline-to-nvvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003epipeline\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003envvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row7-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAsync pipeline ops to NVVM barriers\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row8-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-arith-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row8-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003earith\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row8-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row8-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eArithmetic operations to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row9-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-cf-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row9-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ecf\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row9-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row9-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eControl flow to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row10-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row10-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e*\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row10-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row10-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGeneric catch-all 
LLVM conversion\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row11-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-math-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row11-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003emath\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row11-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row11-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMath operations to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row12-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-nvvm-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row12-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003envvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row12-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row12-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNVVM intrinsics to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row13-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-ub-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row13-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eub\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row13-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row13-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eUndefined behavior ops to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row14-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-vector-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row14-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: 
rgb(75, 85, 99) !important;\"\u003evector\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row14-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row14-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eVector ops to LLVM\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row15-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003econvert-debuginfo-to-llvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row15-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003edebug\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row15-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellvm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"conversion-passes-table-row15-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDebug info to LLVM metadata\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003ch3 id=\"tileas-optimization-passes\"\u003eTileAS Optimization Passes\u003c/h3\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" 
/\u003e\n\n\u003cdiv id=\"tileas-passes-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-8 overflow-x-auto\"\u003e\n \n \u003ctable id=\"tileas-passes-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Pass Name\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Description\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-assign-dot-layouts\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAssign optimal data layouts for dot (MMA) operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"tileas-passes-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-assign-pipeline-layouts\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAssign layouts for async pipeline stages\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-assign-load-store-layouts\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAssign layouts for memory operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-attach-tma-desc-args\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAttach TMA descriptor arguments to kernel signature\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row4-col0\" class=\"px-4 py-2 
whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-dynamic-persistent\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eEnable dynamic persistent kernel execution\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-insert-OCG-knobs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eInsert Online Code Generation tuning knobs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-legalize-tmem-copy\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLegalize tensor memory copy operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-plan-cta\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePlan CTA (thread block) configuration\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row8-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-remove-buffer-alias\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row8-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRemove buffer aliasing for optimization\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row9-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-remove-dead-args\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row9-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDead argument elimination\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row10-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-remove-layout-conversions\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row10-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRemove unnecessary layout conversions\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row11-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-resolve-agent-boundary\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row11-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eResolve warp specialization agent boundaries\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row12-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-slicing\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row12-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTensor slicing for pipelining\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row13-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-materialize-async\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row13-col1\" 
class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMaterialize async load/store/dot operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row14-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-materialize-convert-layout\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row14-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMaterialize layout conversion copy atoms\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row15-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-materialize-schedule\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row15-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMaterialize schedule to warp-specialized IR\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row16-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-unroll-register-loops\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row16-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eUnroll loops at register level\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row17-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-unspecialized-pipeline\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row17-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHandle non-warp-specialized pipelines\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row18-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-optimize-alloc-tensor\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row18-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eOptimize tensor allocation placement\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row19-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-optimize-reduce\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row19-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003eOptimize reduction operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row20-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-recompute-for-scheduling\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row20-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRecompute values for better scheduling\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row21-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-legalize-fma-dot\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row21-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLegalize FMA in dot products\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row22-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-legalize-reduce\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row22-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLegalize reduction operations\u003c/span\u003e\n \u003c/td\u003e\n \n \n 
\u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row23-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-slice-and-fuse\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row23-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSlice and fuse operations for locality\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row24-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-refine-atom-by-resource\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row24-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRefine copy atoms based on resource constraints\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row25-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-generate-schedule\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row25-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGenerate execution schedule (Serial or CostBased)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b 
border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row26-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-prepare-for-scheduling\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row26-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePrepare IR for scheduling pass\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row27-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-optimize-dot-accumulation\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row27-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eOptimize dot product accumulation\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row28-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003elower-tma-load-store-to-async\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row28-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLower TMA ops to async variants\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"tileas-passes-table-row29-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etileas-print-decomposed-tv-layout\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"tileas-passes-table-row29-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDebug: print decomposed tensor view layouts\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003chr /\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eConversion Patterns Registered\u003c/summary\u003e\n\n \u003cp\u003eThe TileAA→TileAS conversion registers 20+ patterns, including:\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eTileAAToTileASTiledLoadOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Tiled load conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASDotOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Dot product conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASExtractOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Extraction conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASBroadcastOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Broadcast conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASGatherLoadOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Gather load conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASScatterStoreOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Scatter store conversion\u003c/span\u003e\n\u003cspan
class=\"n\"\u003eTileAAToTileASExpandDimsOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Dimension expansion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASExtractSliceOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Slice extraction\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASGenerateOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Generate conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASLoadOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Load conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASPermuteOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Permute conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASReduceOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Reduce conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASScanOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Scan conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASStoreOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Store conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASTiledAtomicRMWOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Atomic RMW conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASTiledStoreOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Tiled store conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASViewOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// View conversion\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTileAAToTileASYieldOpPattern\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Yield conversion\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"conclusion\"\u003eConclusion\u003c/h1\u003e\n\n\u003cp\u003eTileIR is a sophisticated MLIR-based compiler that progressively lowers high-level 
tensor operations to optimized GPU machine code. It’s an interesting piece of software that combines MLIR with the rest of NVIDIA’s toolchain to make the tile abstraction work.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eResources:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003ca href=\"https://github.com/NVIDIA/cutile-python\"\u003eCuTile Python\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"https://github.com/NVIDIA/cuda-tile\"\u003eCUDA Tile\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"https://docs.nvidia.com/cuda/tile-ir/\"\u003eNVIDIA TileIR Documentation\u003c/a\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"appendix-tileir-passes-reference\"\u003eAppendix: TileIR Passes Reference\u003c/h1\u003e\n\n\u003cp\u003eThis appendix documents the TileIR-specific passes in the compilation pipeline. Passes are organized into two categories: \u003cstrong\u003eConversion\u003c/strong\u003e and \u003cstrong\u003eTileAS Optimization\u003c/strong\u003e.\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eConversion Passes (16)\u003c/summary\u003e\n\n \u003cp\u003eConversion passes transform IR between MLIR dialects.\u003c/p\u003e\n\n \u003ch3 id=\"convert-cudatile-to-tileaa\"\u003econvert-cudatile-to-tileaa\u003c/h3\u003e\n\n \u003cp\u003eConverts the high-level \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile\u003c/code\u003e dialect to \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e.\u003c/p\u003e\n\n \u003cp\u003e\u003cstrong\u003eKey transformations:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.mmaf\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.dot\u003c/code\u003e\u003c/li\u003e\n
\u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.load_view_tko\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.tiled_load\u003c/code\u003e\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.store_ptr_tko\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.tiled_store\u003c/code\u003e\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecuda_tile.for\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003escf.for\u003c/code\u003e + \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.yield\u003c/code\u003e\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eConvertCudaTileToTileAA\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eModuleOp\u003c/span\u003e \u003cspan class=\"k\"\u003emodule\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003eConversionTarget\u003c/span\u003e \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egetContext\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eaddLegalDialect\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNVTileAADialect\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eaddIllegalDialect\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003ecuda_tile\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eCudaTileDialect\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eRewritePatternSet\u003c/span\u003e \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003egetContext\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Register 20+ conversion patterns\u003c/span\u003e\n \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eadd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eConvertMmafToDot\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eadd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eConvertLoadViewTko\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eadd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eConvertStorePtr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Fails if any explicitly illegal (cuda_tile) op is left unconverted\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efailed\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eapplyPartialConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003emodule\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e))))\u003c/span\u003e\n  \u003cspan class=\"n\"\u003esignalPassFailure\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"convert-tileaa-to-tileas\"\u003econvert-tileaa-to-tileas\u003c/h3\u003e\n\n \u003cp\u003eMain middle-end conversion: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e (Tile Assembly).\u003c/p\u003e\n\n \u003cp\u003e\u003cstrong\u003eKey transformations:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.tiled_load\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas.async_load\u003c/code\u003e + pipeline ops\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileaa.dot\u003c/code\u003e → \u003ccode class=\"language-plaintext
highlighter-rouge\"\u003env_tileas.dot\u003c/code\u003e with layout annotations\u003c/li\u003e\n \u003cli\u003eInserts shared memory allocations\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eConvertTileAAToTileAS\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Walk all tileaa operations\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eTiledLoadOp\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Insert new ops immediately before loadOp\u003c/span\u003e\n \u003cspan class=\"n\"\u003eOpBuilder\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create async copy with TMA descriptor\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003easyncCopy\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan
class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileas\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eAsyncCopyOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Allocate shared memory buffer\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003esmemAlloc\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileas\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eAllocSharedOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eOpBuilder\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Convert to tileas.dot with layout attributes\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003etiledDot\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003env_tileas\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan class=\"n\"\u003etiledDot\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"lhs_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eselectMMALayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(...));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"convert-nv-tileas-to-llvm\"\u003econvert-nv-tileas-to-llvm\u003c/h3\u003e\n\n \u003cp\u003eBackend code generation: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003env_tileas\u003c/code\u003e → LLVM IR with NVVM intrinsics.\u003c/p\u003e\n\n \u003cp\u003e\u003cstrong\u003eKey transformations:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileas.tcgen05_mma\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.tcgen05.mma.*\u003c/code\u003e\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileas.tcgen05_ld\u003c/code\u003e → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.tcgen05.ld.*\u003c/code\u003e\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileas.async_copy\u003c/code\u003e → \u003ccode 
class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.cp.async.*\u003c/code\u003e\u003c/li\u003e\n \u003cli\u003eBarrier ops → \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e@llvm.nvvm.barrier.*\u003c/code\u003e\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eConvertTileASToLLVM\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eModuleOp\u003c/span\u003e \u003cspan class=\"k\"\u003emodule\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eConversionTarget\u003c/span\u003e \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egetContext\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// The patterns emit NVVM ops, so NVVM must be legal alongside LLVM\u003c/span\u003e\n \u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eaddLegalDialect\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eLLVM\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eLLVMDialect\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eNVVM\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNVVMDialect\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eRewritePatternSet\u003c/span\u003e \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan
class=\"n\"\u003egetContext\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// MMA operations\u003c/span\u003e\n \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eadd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eTcgen05MMAToNVVM\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e([](\u003c/span\u003e\u003cspan class=\"n\"\u003etcgen05\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eMMAOp\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Generate NVVM MMA intrinsic\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eNVVM\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eTcgen05MMAOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Tensor memory (TMEM) loads into registers\u003c/span\u003e\n \u003cspan class=\"n\"\u003epatterns\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eadd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eTcgen05LoadToNVVM\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e([](\u003c/span\u003e\u003cspan class=\"n\"\u003etcgen05\u003c/span\u003e\u003cspan 
class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eLoadOp\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eNVVM\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eTcgen05LoadOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(...);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"convert-gpu-to-nvvm\"\u003econvert-gpu-to-nvvm\u003c/h3\u003e\n\n \u003cp\u003eConverts GPU dialect operations to NVVM intrinsics.\u003c/p\u003e\n\n \u003ctable\u003e\n \u003cthead\u003e\n \u003ctr\u003e\n \u003cth\u003eGPU Op\u003c/th\u003e\n \u003cth\u003eNVVM Intrinsic\u003c/th\u003e\n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003egpu.thread_id\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.read.ptx.sreg.tid.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003egpu.block_id\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.read.ptx.sreg.ctaid.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext 
highlighter-rouge\"\u003egpu.block_dim\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.read.ptx.sreg.ntid.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003egpu.barrier\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.barrier0\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003c/tbody\u003e\n \u003c/table\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"convert-pipeline-to-nvvm\"\u003econvert-pipeline-to-nvvm\u003c/h3\u003e\n\n \u003cp\u003eConverts async pipeline operations to NVVM mbarrier intrinsics.\u003c/p\u003e\n\n \u003ctable\u003e\n \u003cthead\u003e\n \u003ctr\u003e\n \u003cth\u003ePipeline Op\u003c/th\u003e\n \u003cth\u003eNVVM Op\u003c/th\u003e\n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003epipeline.producer_acquire\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.mbarrier.arrive.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003epipeline.producer_commit\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.mbarrier.arrive.*\u003c/code\u003e + phase\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003epipeline.consumer_wait\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.mbarrier.wait.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext 
highlighter-rouge\"\u003epipeline.consumer_release\u003c/code\u003e\u003c/td\u003e\n \u003ctd\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003envvm.mbarrier.arrive.*\u003c/code\u003e\u003c/td\u003e\n \u003c/tr\u003e\n \u003c/tbody\u003e\n \u003c/table\u003e\n\n \u003chr /\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eTileAS Optimization Passes (30)\u003c/summary\u003e\n\n \u003cp\u003eTileAS passes optimize and schedule tile operations.\u003c/p\u003e\n\n \u003ch3 id=\"tileas-assign-dot-layouts\"\u003etileas-assign-dot-layouts\u003c/h3\u003e\n\n \u003cp\u003eAssigns MMA-compatible layouts to dot product operands.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eAssignDotLayouts\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan 
class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003elhsType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetLhs\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erhsType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetRhs\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Select MMA shape based on types\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMMAShape\u003c/span\u003e \u003cspan class=\"n\"\u003emmaShape\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eselectMMAShape\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elhsType\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erhsType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Assign layouts for operands\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLayout\u003c/span\u003e \u003cspan class=\"n\"\u003elhsLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeLhsLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003emmaShape\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elhsType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003eLayout\u003c/span\u003e \u003cspan class=\"n\"\u003erhsLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeRhsLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003emmaShape\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erhsType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"lhs_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elhsLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"rhs_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erhsLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003cp\u003e\u003cstrong\u003eMMA shapes:\u003c/strong\u003e \u003ccode class=\"language-plaintext highlighter-rouge\"\u003em16n8k16\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003em16n16k16\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003em64n256k64\u003c/code\u003e\u003c/p\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-assign-load-store-layouts\"\u003etileas-assign-load-store-layouts\u003c/h3\u003e\n\n \u003cp\u003eOptimizes memory access patterns for coalesced 
loads.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eAssignLoadStoreLayouts\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eLoadOp\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003etensorType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetResult\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Check for TMA opportunity\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecanUseTMA\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLayout\u003c/span\u003e \u003cspan class=\"n\"\u003etmaLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeTMALayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etensorType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etmaLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"use_tma\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Vectorized load layout\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLayout\u003c/span\u003e \u003cspan class=\"n\"\u003evecLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeVectorizedLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etensorType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan 
class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003evecLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-assign-pipeline-layouts\"\u003etileas-assign-pipeline-layouts\u003c/h3\u003e\n\n \u003cp\u003eAssigns layouts for async pipeline buffers.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eAssignPipelineLayouts\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003ePipelineOp\u003c/span\u003e \u003cspan class=\"n\"\u003epipelineOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan 
class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003estage\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003epipelineOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetStages\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Assign shared memory layouts for buffers\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003ebuffer\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003estage\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetBuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLayout\u003c/span\u003e \u003cspan class=\"n\"\u003esmemLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeSwizzledLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ebuffer\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebuffer\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003esmemLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan 
class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-generate-schedule\"\u003etileas-generate-schedule\u003c/h3\u003e\n\n \u003cp\u003eGenerates execution schedule using cost-based or serial scheduler.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eGenerateSchedule\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Build dependency graph\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDependencyGraph\u003c/span\u003e \u003cspan class=\"n\"\u003edepGraph\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Select scheduler based on options\u003c/span\u003e\n \u003cspan class=\"n\"\u003eScheduler\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003escheduler\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003euseCostBasedScheduler\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003escheduler\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003eCostBasedScheduler\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edepGraph\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003escheduler\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003eSerialScheduler\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edepGraph\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Generate schedule\u003c/span\u003e\n \u003cspan class=\"n\"\u003eSchedule\u003c/span\u003e \u003cspan class=\"n\"\u003eschedule\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003escheduler\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003egenerateSchedule\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Apply schedule to IR\u003c/span\u003e\n \u003cspan class=\"n\"\u003eapplySchedule\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eschedule\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003cp\u003e\u003cstrong\u003eScheduler types:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n 
\u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eSerial\u003c/code\u003e: Topological order\u003c/li\u003e\n \u003cli\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eCostBased\u003c/code\u003e: Latency-aware with heuristics\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-materialize-schedule\"\u003etileas-materialize-schedule\u003c/h3\u003e\n\n \u003cp\u003eMaterializes abstract schedule into warp-specialized IR.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMaterializeSchedule\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eSchedule\u003c/span\u003e \u003cspan class=\"n\"\u003eschedule\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetSchedule\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eschedule\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetStrategy\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan 
class=\"n\"\u003eStrategy\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eWarpSpecialize\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Split into producer/consumer\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003eproducerOps\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003econsumerOps\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003epartitionOps\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eschedule\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Create agent regions\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecreateAgentRegion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eproducerOps\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eAgentRole\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eProducer\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecreateAgentRegion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econsumerOps\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eAgentRole\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eConsumer\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Insert synchronization\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003einsertBarriers\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eschedule\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-materialize-async\"\u003etileas-materialize-async\u003c/h3\u003e\n\n \u003cp\u003eCreates async pipeline structure with multi-buffering.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMaterializeAsync\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOption\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"num-stages\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan 
class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003escf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eForOp\u003c/span\u003e \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecanPipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create N buffers for N-stage pipeline\u003c/span\u003e\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003ebuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eallocateBuffer\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Transform loop body\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitPrologue\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitSteadyState\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitEpilogue\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebuffers\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-materialize-convert-layout\"\u003etileas-materialize-convert-layout\u003c/h3\u003e\n\n \u003cp\u003eExpands layout conversions to actual data movement.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMaterializeConvertLayout\u003c/span\u003e\u003cspan 
class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eConvertLayoutOp\u003c/span\u003e \u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSource\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003edstLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetResult\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Generate shuffle or shared memory path\u003c/span\u003e\n \u003cspan 
class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecanUseShuffles\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003edstLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitShuffleConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitSharedMemoryConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-attach-tma-desc-args\"\u003etileas-attach-tma-desc-args\u003c/h3\u003e\n\n \u003cp\u003eInjects TMA descriptor arguments into kernel signatures.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eAttachTMADescArgs\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan 
class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eTMAOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003etmaOps\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eusesTMA\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"n\"\u003etmaOps\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecast\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eTMAOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003etmaOp\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan
class=\"n\"\u003etmaOps\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create TMA descriptor type\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003edescType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eTMADescriptorType\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eget\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"n\"\u003etmaOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetShape\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e\n \u003cspan class=\"n\"\u003etmaOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetElementType\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e\n \u003cspan class=\"n\"\u003etmaOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSwizzle\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Add to function arguments\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003einsertArgument\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edescType\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"s\"\u003e\"tma_desc\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-slicing\"\u003etileas-slicing\u003c/h3\u003e\n\n \u003cp\u003eSlices 
tensors for pipelined execution.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASSlicing\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eLoadOp\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003etensorType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetResult\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003esliceDim\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetSliceDimension\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeSliceSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etensorType\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003esliceDim\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003enumSlices\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecast\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eShapedType\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etensorType\u003c/span\u003e\u003cspan class=\"p\"\u003e).\u003c/span\u003e\u003cspan class=\"n\"\u003egetDimSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esliceDim\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e/\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Replace single load with sliced loads\u003c/span\u003e\n \u003cspan class=\"n\"\u003eOpBuilder\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan class=\"c1\"\u003e// new ops are inserted before loadOp\u003c/span\u003e\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003eslices\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003enumSlices\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eslice\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecreate\u003c/span\u003e\u003cspan
class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eSlicedLoadOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetLoc\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e \u003cspan class=\"n\"\u003eloadOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSource\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e \u003cspan class=\"n\"\u003esliceDim\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e\n \u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eslices\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eslice\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-plan-cta\"\u003etileas-plan-cta\u003c/h3\u003e\n\n \u003cp\u003ePlans CTA (thread block) configuration.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003ePlanCTA\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan
class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Analyze resource requirements\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003esmemRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eanalyzeSharedMemory\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eregsRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eanalyzeRegisters\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Compute optimal CTA shape\u003c/span\u003e\n \u003cspan class=\"n\"\u003eCTAConfig\u003c/span\u003e \u003cspan class=\"n\"\u003econfig\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeCTAConfig\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"n\"\u003esmemRequired\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eregsRequired\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etargetOccupancy\u003c/span\u003e\n \u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"cta_shape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"n\"\u003econfig\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003etoAttribute\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-resolve-agent-boundary\"\u003etileas-resolve-agent-boundary\u003c/h3\u003e\n\n \u003cp\u003eResolves data flow across warp specialization boundaries.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eResolveAgentBoundary\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eAgentSwitchOp\u003c/span\u003e \u003cspan class=\"n\"\u003eswitchOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Identify values crossing boundary\u003c/span\u003e\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan 
class=\"n\"\u003eValue\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003ecrossingValues\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e \u003cspan class=\"n\"\u003ev\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eswitchOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetOperands\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecrossesBoundary\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eswitchOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecrossingValues\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Insert shared memory communication\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e \u003cspan class=\"n\"\u003ev\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003ecrossingValues\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertSharedMemoryTransfer\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eswitchOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-remove-buffer-alias\"\u003etileas-remove-buffer-alias\u003c/h3\u003e\n\n \u003cp\u003eRemoves buffer aliasing using fixed-point iteration.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRemoveBufferAlias\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003echanged\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003echanged\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003echanged\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nb\"\u003efalse\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eAllocTensorOp\u003c/span\u003e \u003cspan class=\"n\"\u003eallocOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Early-increment iteration: rewriting a use must not invalidate the loop\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003euse\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003ellvm\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emake_early_inc_range\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eallocOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetResult\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetUses\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eisAliasingUse\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003euse\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecreateNonAliasingBuffer\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003euse\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003echanged\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e
\u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-remove-dead-args\"\u003etileas-remove-dead-args\u003c/h3\u003e\n\n \u003cp\u003eRemoves unused arguments from region operations.\u003c/p\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-remove-layout-conversions\"\u003etileas-remove-layout-conversions\u003c/h3\u003e\n\n \u003cp\u003eEliminates redundant layout conversions.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRemoveLayoutConversions\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eConvertLayoutOp\u003c/span\u003e \u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSource\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003edstLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetResult\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Remove identity conversions\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"n\"\u003edstLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ereplaceAllUsesWith\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSource\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"n\"\u003econvertOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eerase\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n 
\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-optimize-alloc-tensor\"\u003etileas-optimize-alloc-tensor\u003c/h3\u003e\n\n \u003cp\u003eOptimizes tensor allocations through reuse and elimination.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eOptimizeAllocTensor\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLivenessAnalysis\u003c/span\u003e \u003cspan class=\"n\"\u003eliveness\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eAllocTensorOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003eallocs\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan 
class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eAllocTensorOp\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e \u003cspan class=\"n\"\u003eallocs\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003ealloc\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eallocs\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Find reusable buffer\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003ereusable\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003efindReusableBuffer\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealloc\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eliveness\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ereuseBuffer\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealloc\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ereusable\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-optimize-reduce\"\u003etileas-optimize-reduce\u003c/h3\u003e\n\n \u003cp\u003eOptimizes reduction operations with warp shuffle or shared memory.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eOptimizeReduce\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eReduceOp\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ereductionSize\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetReductionSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereductionSize\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;=\u003c/span\u003e \u003cspan class=\"mi\"\u003e32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003esetAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"s\"\u003e\"warp_shuffle\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereductionSize\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;=\u003c/span\u003e \u003cspan class=\"mi\"\u003e1024\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003esetAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"s\"\u003e\"shared_memory\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003esetAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"s\"\u003e\"multi_stage\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan 
class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-optimize-dot-accumulation\"\u003etileas-optimize-dot-accumulation\u003c/h3\u003e\n\n \u003cp\u003eClassifies how each dot (MMA) op accumulates (a simple loop, split-K, or stream-K) and applies a pattern-specific optimization to improve register utilization.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eOptimizeDotAccumulation\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eaccumPattern\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eanalyzeAccumulationPattern\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eswitch\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eaccumPattern\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ecase\u003c/span\u003e \u003cspan class=\"n\"\u003eAccumPattern\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eSimpleLoop\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan class=\"n\"\u003eoptimizeSimpleAccumulation\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ebreak\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003ecase\u003c/span\u003e \u003cspan class=\"n\"\u003eAccumPattern\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eSplitK\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan class=\"n\"\u003eoptimizeSplitKAccumulation\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ebreak\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003ecase\u003c/span\u003e \u003cspan class=\"n\"\u003eAccumPattern\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStreamK\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan class=\"n\"\u003eoptimizeStreamKAccumulation\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ebreak\u003c/span\u003e\u003cspan 
class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-recompute-for-scheduling\"\u003etileas-recompute-for-scheduling\u003c/h3\u003e\n\n \u003cp\u003eTrades extra computation for lower register pressure: a value that is cheap to compute is recomputed rather than kept live whenever keeping it live would cost more in register spills.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASRecomputeForScheduling\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003eRegisterPressureAnalysis\u003c/span\u003e \u003cspan class=\"n\"\u003eregPressure\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e \u003cspan class=\"n\"\u003eresult\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003egetResults\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eshouldRecompute\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eregPressure\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003emarkForRecomputation\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n \u003cspan class=\"n\"\u003eapplyRecomputations\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"nf\"\u003eshouldRecompute\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e \u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eRegisterPressureAnalysis\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003erpa\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Recompute if value is cheap but keeping it live causes spills\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeCost\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eestimateComputeCost\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetDefiningOp\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003espillCost\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003erpa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eestimateSpillCost\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ev\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeCost\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003espillCost\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-legalize-fma-dot\"\u003etileas-legalize-fma-dot\u003c/h3\u003e\n\n \u003cp\u003eLegalizes dot ops that accumulate via FMA so they match hardware capabilities: an unsupported accumulator type is converted to a legal one, and mixed-precision dots get their own legalization.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan 
class=\"n\"\u003eLegalizeFmaDot\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ehasFmaAccumulation\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003elegalizeFma\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003elegalizeFma\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eaccType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetAccumulator\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eisLegalAccumulatorType\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eaccType\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003elegalType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLegalAccumulatorType\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eaccType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertAccumulatorConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elegalType\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eisMixedPrecision\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003elegalizeMixedPrecisionFma\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edotOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-legalize-reduce\"\u003etileas-legalize-reduce\u003c/h3\u003e\n\n \u003cp\u003eEnsures reductions use supported types and sizes.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eLegalizeReduce\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eReduceOp\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan 
class=\"n\"\u003eisLegalReduction\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003elegalizeReduction\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003elegalizeReduction\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eReduceOp\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003einputType\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetInput\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003egetType\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003ereductionKind\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetReductionKind\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan 
class=\"n\"\u003eisSupportedElementType\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003einputType\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetElementType\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertTypeConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eisSupportedReductionSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003einputType\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetReductionDim\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003esplitReduction\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ereduceOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-legalize-tmem-copy\"\u003etileas-legalize-tmem-copy\u003c/h3\u003e\n\n \u003cp\u003eLegalizes tensor memory (tmem) copy operations. 
Tensor memory is dedicated storage for tensor core operands.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASLegalizeTmemCopy\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edyn_cast\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eCopyOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan 
class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003einvolvesTmem\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003elegalizeTmemCopy\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003elegalizeTmemCopy\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eCopyOp\u003c/span\u003e \u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetSource\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003edstLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetDest\u003c/span\u003e\u003cspan 
class=\"p\"\u003e());\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Infer register layout from tmem layout\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eregLayout\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003einferRegisterLayoutFromTmem\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Insert necessary layout conversions\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eneedsConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eregLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertLayoutConversion\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecopyOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003esrcLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eregLayout\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-slice-and-fuse\"\u003etileas-slice-and-fuse\u003c/h3\u003e\n\n \u003cp\u003eApplies loop tiling (slicing) and fusion for improved data locality.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan 
class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eSliceAndFuse\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eSmallVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eFusionGroup\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003efusionGroups\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecollectFusionCandidates\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003efusionGroups\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003egroup\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003efusionGroups\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecomputeOptimalSliceSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egroup\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003esliceOperations\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egroup\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003esliceSize\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuseOperations\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egroup\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003efuseOperations\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFusionGroup\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003egroup\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create fused loop nest\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// - Single loop iterating over slices\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// - Multiple operations per slice iteration\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003efusedLoop\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecreateFusedLoop\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egroup\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan 
class=\"n\"\u003egroup\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetOperations\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003emoveIntoFusedLoop\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003efusedLoop\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-refine-atom-by-resource\"\u003etileas-refine-atom-by-resource\u003c/h3\u003e\n\n \u003cp\u003eAdjusts operation granularity (“atom”) based on available hardware resources.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRefineAtomByResource\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetTargetResources\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ehasAtomAttribute\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003erefineAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003erefineAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eResourceConstraints\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n 
\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003ecurrentAtom\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003esmemRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eestimateSmemUsage\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ecurrentAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eregsRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eestimateRegUsage\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ecurrentAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Refine if over resource limits (SM120: 228KB smem, 65536 regs)\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esmemRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003emaxSmem\u003c/span\u003e \u003cspan class=\"o\"\u003e||\u003c/span\u003e\n \u003cspan class=\"n\"\u003eregsRequired\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003emaxRegs\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erefinedAtom\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003efindSmallerAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eresources\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003esetAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erefinedAtom\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-prepare-for-scheduling\"\u003etileas-prepare-for-scheduling\u003c/h3\u003e\n\n \u003cp\u003eNormalizes IR and annotates operation latencies for the scheduler.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003ePrepareForScheduling\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan 
class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003enormalizeLoops\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertSchedulingAnchors\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eannotateLatencies\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eidentifyBarriers\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003eannotateLatencies\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003elatency\u003c/span\u003e \u003cspan 
class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eestimateLatency\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"sched.latency\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetI64IntegerAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elatency\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-unroll-register-loops\"\u003etileas-unroll-register-loops\u003c/h3\u003e\n\n \u003cp\u003eUnrolls loops that access register-resident tensors (required since GPU registers cannot be dynamically indexed).\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASUnrollRegisterLoops\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan 
class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003escf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eForOp\u003c/span\u003e \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eaccessesRegisterTensors\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003ecanAvoidUnroll\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Must unroll - register tensors require static indexing\u003c/span\u003e\n \u003cspan class=\"n\"\u003eunrollLoop\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan 
class=\"nf\"\u003eaccessesRegisterTensors\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003escf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eForOp\u003c/span\u003e \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eaccessesRegs\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nb\"\u003efalse\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eValue\u003c/span\u003e \u003cspan class=\"n\"\u003eoperand\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003egetOperands\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eisRegisterTensor\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eoperand\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e 
\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eaccessesRegs\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eaccessesRegs\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-unspecialized-pipeline\"\u003etileas-unspecialized-pipeline\u003c/h3\u003e\n\n \u003cp\u003eImplements software pipelining without warp specialization (all warps do both load and compute).\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASUnspecializedPipeline\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOption\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"unspecialized-pipeline-num-stages\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003escf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eForOp\u003c/span\u003e \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ecanPipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eapplySoftwarePipelining\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003eapplySoftwarePipelining\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003escf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eForOp\u003c/span\u003e \u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitPrologue\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Pre-load data for first N iterations\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitSteadyState\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Overlap load(i+N) with compute(i)\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitEpilogue\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eforOp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003enumStages\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Drain remaining computations\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-dynamic-persistent\"\u003etileas-dynamic-persistent\u003c/h3\u003e\n\n \u003cp\u003eTransforms kernels into dynamic persistent kernels that process work items from a queue.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan 
class=\"n\"\u003eTileASDynamicPersistent\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003ehasAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"dynamic_persistent\"\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eemitWarning\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"Kernel is already dynamic persistent\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003etransformToPersistent\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"dynamic_persistent\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebuilder\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003egetUnitAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003etransformToPersistent\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Insert outer loop that fetches work items:\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// while (workAvailable()) {\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// workItem = fetchWork();\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// processWorkItem(workItem);\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// signalCompletion();\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// }\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003chr /\u003e\n\n \u003ch3 id=\"tileas-insert-ocg-knobs\"\u003etileas-insert-OCG-knobs\u003c/h3\u003e\n\n \u003cp\u003eInserts OCG (Optimizing Code Generator) hints for the PTXAS backend.\u003c/p\u003e\n\n \u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eTileASInsertOCGKnobs\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003erunOnOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eFuncOp\u003c/span\u003e \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan 
class=\"n\"\u003egetOperation\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003efuncOp\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ewalk\u003c/span\u003e\u003cspan class=\"p\"\u003e([\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e](\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eloopOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edyn_cast\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eLoopOp\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertOCGDirectives\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eloopOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003emmaOp\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003edyn_cast\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eDotOp\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003einsertMMAOptimizationHints\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003emmaOp\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e});\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003einsertOCGDirectives\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eOperation\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"ocgEnterDirectives\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebuildOCGDirectives\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"cm\"\u003e/*enter=*/\u003c/span\u003e\u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esetAttr\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"ocgLeaveDirectives\"\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n 
\u003cspan class=\"n\"\u003ebuildOCGDirectives\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eop\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"cm\"\u003e/*enter=*/\u003c/span\u003e\u003cspan class=\"nb\"\u003efalse\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003chr /\u003e\n\n\u003ch1 id=\"appendix-ir-dumps\"\u003eAppendix: IR Dumps\u003c/h1\u003e\n\n\u003cp\u003eThis appendix contains the IR dumps from the MoE kernel compilation. Some of the IR below uses \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e%0\u003c/code\u003e placeholders.\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003ecuda_tile IR\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003ecuda_tile\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eHigh-level\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e \u003cspan class=\"k\"\u003efrom\u003c/span\u003e \u003cspan class=\"err\"\u003eCuTile\u003c/span\u003e \u003cspan class=\"err\"\u003ePython\u003c/span\u003e \u003cspan class=\"err\"\u003eAPI\u003c/span\u003e\n\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#1\u003c/span\u003e \u003cspan 
class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003ecuda_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eentry\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"cuda_tile.module\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"cuda_tile.entry\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.constant\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_tensor_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_tensor_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_token\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003e%1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%2\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.get_tile_block_id\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e 
\u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.get_tile_block_id\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.subi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.mini\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.subi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.remi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.get_tile_block_id\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.mini\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.mini\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.xori\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.andi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.xori\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.mini\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.select\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.select\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.remi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.get_tile_block_id\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.xori\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.andi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.xori\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.select\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.remi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.select\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.mini\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.iota\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.iota\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.offset\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.offset\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.make_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.load_ptr_tko\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_partition_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.make_tensor_view\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_view_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.make_partition_view\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.make_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.load_view_tko\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.divi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.iota\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.for\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan 
class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.iota\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.andi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.offset\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.offset\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.make_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.make_partition_view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.make_tensor_view\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003etoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.load_view_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.make_partition_view\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e 
\u003cspan class=\"nv\"\u003e%cuda_tile.make_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003epart\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.load_view_tko\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.mmaf\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.load_ptr_tko\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"cuda_tile.continue\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.mmaf\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.divi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.iota\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.iota\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.ftof\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.for\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.load_ptr_tko\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.muli\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.exti\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.exti\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"cuda_tile.cmpi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.andi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.addi\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.reshape\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.reshape\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.offset\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.addi\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"cuda_tile.store_ptr_tko\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%cuda_tile.offset\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.ftof\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%cuda_tile.make_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e 
\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eview\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ect\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"cuda_tile.return\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003env_tileaa IR\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTile-level\u003c/span\u003e \u003cspan class=\"err\"\u003eops\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003earchitecture-independent\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#1\u003c/span\u003e 
\u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileaa.func\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ekernel_spec\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e 
\u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e 
\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.assume\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_memref\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_memref\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.create_mem_token\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.get_program_id\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.get_program_id\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_range\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%arith.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.addptr\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.splat\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.addptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.block_tile\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.make_memref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.block_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.tiled_load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_range\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.floordivsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e 
\u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.ceildivsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.addptr\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.splat\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.addptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.block_tile\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.make_memref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.convert_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.floordivsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e 
\u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.block_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.tiled_load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.load\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.view\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.make_range\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.extract\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.muli\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan 
class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.fp_to_fp\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%scf.for\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.broadcast\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.extsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.splat\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.assume\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ememref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.addptr\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.splat\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileaa.store\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.addptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.fp_to_fp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileaa.return\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#2\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#3\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan 
class=\"vg\"\u003e#4\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#5\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#6\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#7\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e 
\u003cspan class=\"vg\"\u003e#8\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#9\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#10\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#11\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan 
class=\"err\"\u003ePass\u003c/span\u003e \u003cspan class=\"vg\"\u003e#12\u003c/span\u003e \u003cspan class=\"err\"\u003escope\u003c/span\u003e\u003cspan class=\"p\"\u003e=\u003c/span\u003e\u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eLines\u003c/span\u003e \u003cspan class=\"m\"\u003e193-352\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003efinal\u003c/span\u003e \u003cspan class=\"err\"\u003eassembly\u003c/span\u003e \u003cspan class=\"err\"\u003ewith\u003c/span\u003e \u003cspan class=\"err\"\u003efp_to_fp\u003c/span\u003e \u003cspan class=\"err\"\u003econversions\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eSee\u003c/span\u003e \u003cspan class=\"err\"\u003edump\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003ecomplete\u003c/span\u003e \u003cspan class=\"err\"\u003econtent\u003c/span\u003e \u003cspan class=\"nl\"\u003eincluding:\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"m\"\u003e32\u003c/span\u003e \u003cspan class=\"err\"\u003efp_to_fp\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003eoutput\u003c/span\u003e \u003cspan class=\"err\"\u003eprecision\u003c/span\u003e \u003cspan class=\"err\"\u003econversion\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003eMultiple\u003c/span\u003e \u003cspan 
class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e \u003cspan class=\"err\"\u003edeclarations\u003c/span\u003e \u003cspan class=\"err\"\u003ewith\u003c/span\u003e \u003cspan class=\"err\"\u003ekernel\u003c/span\u003e \u003cspan class=\"kt\"\u003emetadata\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003eFinal\u003c/span\u003e \u003cspan class=\"err\"\u003ememory\u003c/span\u003e \u003cspan class=\"err\"\u003elayout\u003c/span\u003e \u003cspan class=\"err\"\u003epreparation\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003env_tileas IR\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003env_tileas\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTile-level\u003c/span\u003e \u003cspan class=\"err\"\u003eScheduled\u003c/span\u003e \u003cspan class=\"err\"\u003eAssembly\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003earchitecture-specific\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"k\"\u003ewithin\u003c/span\u003e \u003cspan class=\"err\"\u003env_tileaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003efunc\u003c/span\u003e 
\u003cspan class=\"err\"\u003epass\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.addptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.cmpi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.block_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.tiled_load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.floordivsi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.expand_dims\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arith.addi\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.broadcast\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.addptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.andi\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.view\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.tiled_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%nv_tileaa.block_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.create_mem_token\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emtoken\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.view\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"nv\"\u003e%nv_tileas.tiled_load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.load\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.view\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.dot\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.convert_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.convert_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.convert_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.convert_layout\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.dot\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.make_tiled_tma_desc\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.make_memref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003etmaIdx\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"k\"\u003ewithin\u003c/span\u003e \u003cspan class=\"k\"\u003ebuiltin\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"k\"\u003emodule\u003c/span\u003e \u003cspan class=\"err\"\u003epass\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_pipeline\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_pipeline\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan 
class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_pipeline\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.create_iterator\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.create_pipeline\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.agent_switch\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e...)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(...)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eProducer-Consumer\u003c/span\u003e \u003cspan class=\"err\"\u003ePattern\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003erepeated\u003c/span\u003e \u003cspan class=\"err\"\u003ethroughout\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_acquire\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan 
class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.inc_iter\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_write\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.producer_acquire\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.producer_commit\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.producer_write\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_wait\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%1\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_read\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.consumer_wait\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003econsumer_idx\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e \u003cspan class=\"err\"\u003eregions\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nv_tileas.async.pipeline.consumer_release\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.async.pipeline.consumer_read\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eDot\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e100\u003c/span\u003e\u003cspan class=\"err\"\u003e+\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003etiled\u003c/span\u003e \u003cspan class=\"err\"\u003ematrix\u003c/span\u003e \u003cspan class=\"err\"\u003emultiply\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.dot\"\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileas.extract_slice\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.extract_slice\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003erepeated\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003eall\u003c/span\u003e \u003cspan class=\"err\"\u003etile\u003c/span\u003e \u003cspan class=\"err\"\u003epartitions\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eTMA\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan 
class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.make_tiled_tma_desc\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.make_memref\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003etmaIdx\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eaa\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003ebtile\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.async.tiled_tma_load\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.block_tile\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.make_tiled_tma_desc\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arg\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileaa.extract\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(...)\u003c/span\u003e \u003cspan 
class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003e?\u003c/span\u003e\u003cspan class=\"k\"\u003etype\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eOutput\u003c/span\u003e \u003cspan class=\"err\"\u003eassembly\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e32\u003c/span\u003e \u003cspan class=\"err\"\u003einsert_slice\u003c/span\u003e \u003cspan class=\"err\"\u003efor\u003c/span\u003e \u003cspan class=\"err\"\u003eoutput\u003c/span\u003e \u003cspan class=\"err\"\u003etiles\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nv_tileas.insert_slice\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%nv_tileaa.fp_to_fp\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%nv_tileas.alloc_tensor\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%arith.constant\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003eiN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e 
\u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;...\u0026gt;)\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003erepeated\u003c/span\u003e \u003cspan class=\"m\"\u003e32\u003c/span\u003e \u003cspan class=\"err\"\u003etimes\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eNVVM Dialect IR\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003envvm\u003c/span\u003e \u003cspan class=\"err\"\u003edialect\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003eNVVM\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003eNVIDIA\u003c/span\u003e \u003cspan class=\"err\"\u003ePTX\u003c/span\u003e \u003cspan class=\"err\"\u003eintrinsics\u003c/span\u003e \u003cspan class=\"err\"\u003ein\u003c/span\u003e \u003cspan class=\"err\"\u003eMLIR\u003c/span\u003e \u003cspan class=\"err\"\u003eform\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eBarrier\u003c/span\u003e \u003cspan class=\"k\"\u003eand\u003c/span\u003e \u003cspan 
class=\"err\"\u003eFence\u003c/span\u003e \u003cspan class=\"err\"\u003eOperations\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.fence.mbarrier.init\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.barrier\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.fence.proxy\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvvm.read.ptx.sreg.clusterid.x\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvvm.read.ptx.sreg.tid.x\"\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eAsync\u003c/span\u003e \u003cspan class=\"err\"\u003eGlobal→Shared\u003c/span\u003e \u003cspan class=\"err\"\u003eCopies\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e136\u003c/span\u003e \u003cspan 
class=\"err\"\u003einstances\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.cp.async.shared.global\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%ptr\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%src\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%predicate\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eTensor\u003c/span\u003e \u003cspan class=\"err\"\u003eCore\u003c/span\u003e \u003cspan class=\"err\"\u003eData\u003c/span\u003e \u003cspan class=\"err\"\u003ePacking\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"m\"\u003e088\u003c/span\u003e \u003cspan class=\"err\"\u003einstances\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan 
class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvvm.cvt.packfloat.f32\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%a\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%b\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mode\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003ef32\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ef32\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eMemory\u003c/span\u003e \u003cspan class=\"err\"\u003eBarriers\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e66\u003c/span\u003e \u003cspan class=\"err\"\u003einstances\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.mbarrier.init.shared\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%barrier\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%count\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.mbarrier.arrive.shared\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%barrier\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"s\"\u003e\"nvvm.mbarrier.wait.shared\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%barrier\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%phase\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan 
class=\"p\"\u003e()\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eMatrix\u003c/span\u003e \u003cspan class=\"err\"\u003eLoad\u003c/span\u003e \u003cspan class=\"err\"\u003eOperations\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e512\u003c/span\u003e \u003cspan class=\"err\"\u003einstances\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvvm.ldmatrix\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%ptr\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\u003cspan class=\"err\"\u003elayout\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"err\"\u003e#nvvm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emma_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"err\"\u003erow\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003enum\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"m\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"err\"\u003evector\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e \u003cspan class=\"err\"\u003eTensor\u003c/span\u003e \u003cspan class=\"err\"\u003eCore\u003c/span\u003e \u003cspan class=\"err\"\u003eMMA\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e512\u003c/span\u003e \u003cspan class=\"err\"\u003einstances\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e===\u003c/span\u003e\n\u003cspan class=\"nv\"\u003e%0\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvvm.mma.sync\"\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003e%a\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%b\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"err\"\u003elayoutA\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"err\"\u003e#nvvm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emma_layout\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"err\"\u003erow\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e\n \u003cspan class=\"err\"\u003elayoutB\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"err\"\u003e#nvvm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003emma_layout\u003c/span\u003e\u003cspan 
class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"err\"\u003ecol\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e\n \u003cspan class=\"err\"\u003eshape\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"err\"\u003e#nvvm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"err\"\u003eshape\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"err\"\u003em\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"m\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003en\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"m\"\u003e8\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ek\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"m\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"err\"\u003e:\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"err\"\u003evector\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003evector\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"err\"\u003evector\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan 
class=\"err\"\u003ef32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"err\"\u003evector\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003ex\u003c/span\u003e\u003cspan class=\"err\"\u003ef32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\n\u003cspan class=\"err\"\u003e//\u003c/span\u003e \u003cspan class=\"p\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"m\"\u003e977\u003c/span\u003e \u003cspan class=\"err\"\u003elines\u003c/span\u003e \u003cspan class=\"err\"\u003etotal\u003c/span\u003e \u003cspan class=\"err\"\u003e-\u003c/span\u003e \u003cspan class=\"err\"\u003etensor\u003c/span\u003e \u003cspan class=\"err\"\u003ecore\u003c/span\u003e \u003cspan class=\"err\"\u003eoperations\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ebarriers\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"err\"\u003ememory\u003c/span\u003e \u003cspan class=\"err\"\u003eops\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eLLVM IR / NVVM IR\u003c/summary\u003e\n\n \u003cdiv class=\"language-llvm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e; ModuleID = 'LLVMDialectModule'\u003c/span\u003e\n\u003cspan class=\"k\"\u003etarget\u003c/span\u003e \u003cspan 
class=\"k\"\u003edatalayout\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64\"\u003c/span\u003e\n\u003cspan class=\"k\"\u003etarget\u003c/span\u003e \u003cspan class=\"k\"\u003etriple\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"s\"\u003e\"nvptx64-nvidia-cuda\"\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; Kernel entry point with TMA descriptors\u003c/span\u003e\n\u003cspan class=\"k\"\u003edefine\u003c/span\u003e \u003cspan class=\"k\"\u003eptx_kernel\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@fused_moe_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%A\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e; Input tokens\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%B\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e; Expert weights\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%C\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e; Output\u003c/span\u003e\n 
\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%topk_weights\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%sorted_token_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%sorted_expert_ids\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"nv\"\u003e%num_token_replicas\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e \u003cspan class=\"nv\"\u003e%mul_routed_weight\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e; ... 
TMA descriptors appended by tileas-attach-tma-desc-args\u003c/span\u003e\n\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"vg\"\u003e#0\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n\u003cspan class=\"nl\"\u003eentry:\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e; Get cluster/block/thread IDs\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%clusterid\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.read.ptx.sreg.clusterid.x\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%tid\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"err\"\u003erange\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"m\"\u003e384\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.read.ptx.sreg.tid.x\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e; Initialize barriers for async pipeline\u003c/span\u003e\n \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mbarrier.init.shared\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%barrier\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e128\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e; Async copy from global to shared memory\u003c/span\u003e\n \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.cp.async.shared.global\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%shared_dst\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"nv\"\u003e%global_src\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"m\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"c1\"\u003e; bytes\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e \u003cspan class=\"nv\"\u003e%pred\u003c/span\u003e \u003cspan class=\"c1\"\u003e; predicate\u003c/span\u003e\n \u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e; Tensor core matrix multiply\u003c/span\u003e\n \u003cspan class=\"nv\"\u003e%result\u003c/span\u003e \u003cspan class=\"p\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003ecall\u003c/span\u003e \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e 
\u003cspan class=\"kt\"\u003efloat\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mma.m16n8k16.row.col.f32.f16.f16.f32\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"nv\"\u003e%a_frag\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e2\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"nv\"\u003e%b_frag\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"nv\"\u003e%c_frag\u003c/span\u003e\n \u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e; ... 
(full pipeline with producer/consumer synchronization)\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e; NVVM intrinsic declarations\u003c/span\u003e\n\u003cspan class=\"k\"\u003edeclare\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.read.ptx.sreg.tid.x\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"k\"\u003edeclare\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.read.ptx.sreg.clusterid.x\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"k\"\u003edeclare\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mbarrier.init.shared\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003edeclare\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.cp.async.shared.global\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan class=\"kt\"\u003eptr\u003c/span\u003e \u003cspan class=\"k\"\u003eaddrspace\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"m\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan 
class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003ei1\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003edeclare\u003c/span\u003e \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"vg\"\u003e@llvm.nvvm.mma.m16n8k16.row.col.f32.f16.f16.f32\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e2\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003ei32\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;,\u003c/span\u003e \u003cspan class=\"p\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"m\"\u003e4\u003c/span\u003e \u003cspan class=\"p\"\u003ex\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003ePTX Assembly\u003c/summary\u003e\n\n \u003cdiv class=\"language-nasm highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"o\"\u003e//\u003c/span\u003e\n\u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003eGenerated\u003c/span\u003e \u003cspan class=\"nv\"\u003eby\u003c/span\u003e \u003cspan 
class=\"nv\"\u003eNVIDIA\u003c/span\u003e \u003cspan class=\"nv\"\u003eNVVM\u003c/span\u003e \u003cspan class=\"nv\"\u003eCompiler\u003c/span\u003e\n\u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003eCuda\u003c/span\u003e \u003cspan class=\"nv\"\u003ecompilation\u003c/span\u003e \u003cspan class=\"nv\"\u003etools\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003erelease\u003c/span\u003e \u003cspan class=\"mf\"\u003e13.1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"nv\"\u003eV13.1.80\u003c/span\u003e\n\u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003eBased\u003c/span\u003e \u003cspan class=\"nv\"\u003eon\u003c/span\u003e \u003cspan class=\"nv\"\u003eNVVM\u003c/span\u003e \u003cspan class=\"mf\"\u003e21.0\u003c/span\u003e\u003cspan class=\"nv\"\u003e.0\u003c/span\u003e\n\u003cspan class=\"o\"\u003e//\u003c/span\u003e\n\n\u003cspan class=\"nf\"\u003e.version\u003c/span\u003e \u003cspan class=\"mf\"\u003e9.1\u003c/span\u003e\n\u003cspan class=\"nf\"\u003e.target\u003c/span\u003e \u003cspan class=\"nv\"\u003esm_120a\u003c/span\u003e\n\u003cspan class=\"nf\"\u003e.address_size\u003c/span\u003e \u003cspan class=\"mi\"\u003e64\u003c/span\u003e\n\n\u003cspan class=\"nf\"\u003e.visible\u003c/span\u003e \u003cspan class=\"nv\"\u003e.entry\u003c/span\u003e \u003cspan class=\"nv\"\u003efused_moe_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.param\u003c/span\u003e \u003cspan class=\"nv\"\u003e.u64\u003c/span\u003e \u003cspan class=\"nv\"\u003e.ptr\u003c/span\u003e \u003cspan class=\"nv\"\u003e.global\u003c/span\u003e \u003cspan class=\"nv\"\u003e.align\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e \u003cspan class=\"nv\"\u003efused_moe_kernel_param_0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.param\u003c/span\u003e 
\u003cspan class=\"nv\"\u003e.u32\u003c/span\u003e \u003cspan class=\"nv\"\u003efused_moe_kernel_param_1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003e...\u003c/span\u003e \u003cspan class=\"mi\"\u003e31\u003c/span\u003e \u003cspan class=\"nv\"\u003eparameters\u003c/span\u003e \u003cspan class=\"nv\"\u003etotal\u003c/span\u003e \u003cspan class=\"nv\"\u003eincluding\u003c/span\u003e \u003cspan class=\"nv\"\u003eTMA\u003c/span\u003e \u003cspan class=\"nv\"\u003edescriptors\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.hidden\u003c/span\u003e \u003cspan class=\"nv\"\u003e.param\u003c/span\u003e \u003cspan class=\"nv\"\u003e.align\u003c/span\u003e \u003cspan class=\"mi\"\u003e64\u003c/span\u003e \u003cspan class=\"nv\"\u003e.b8\u003c/span\u003e \u003cspan class=\"nv\"\u003efused_moe_kernel_param_31\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e128\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"nf\"\u003e.reqntid\u003c/span\u003e \u003cspan class=\"mi\"\u003e384\u003c/span\u003e\n\u003cspan class=\"nf\"\u003e.minnctapersm\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\n\u003cspan class=\"err\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.reg\u003c/span\u003e \u003cspan class=\"nv\"\u003e.pred\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"mi\"\u003e306\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.reg\u003c/span\u003e \u003cspan class=\"nv\"\u003e.b16\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ers\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"mi\"\u003e500\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.reg\u003c/span\u003e \u003cspan class=\"nv\"\u003e.b32\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"mi\"\u003e4905\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.reg\u003c/span\u003e \u003cspan class=\"nv\"\u003e.b64\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003erd\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"mi\"\u003e348\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e80\u003c/span\u003e\u003cspan class=\"nf\"\u003eKB\u003c/span\u003e \u003cspan class=\"nv\"\u003eshared\u003c/span\u003e \u003cspan class=\"nv\"\u003ememory\u003c/span\u003e \u003cspan class=\"nv\"\u003efor\u003c/span\u003e \u003cspan class=\"nv\"\u003edouble\u003c/span\u003e \u003cspan class=\"nv\"\u003ebuffering\u003c/span\u003e\n \u003cspan class=\"nf\"\u003e.shared\u003c/span\u003e \u003cspan class=\"nv\"\u003e.align\u003c/span\u003e \u003cspan class=\"mi\"\u003e128\u003c/span\u003e \u003cspan class=\"nv\"\u003e.b8\u003c/span\u003e \u003cspan class=\"nv\"\u003eglobal_smem\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e82032\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e \u003cspan 
class=\"nf\"\u003eBarrier\u003c/span\u003e \u003cspan class=\"nv\"\u003eInitialization\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e\n \u003cspan class=\"nf\"\u003embarrier.init.shared.b64\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eglobal_smem\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mi\"\u003e82000\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2369\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003embarrier.init.shared.b64\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eglobal_smem\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mi\"\u003e82008\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2369\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e \u003cspan class=\"nf\"\u003eMatrix\u003c/span\u003e \u003cspan class=\"nv\"\u003eLoad\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003eldmatrix\u003c/span\u003e \u003cspan class=\"nv\"\u003efor\u003c/span\u003e \u003cspan class=\"nv\"\u003etensor\u003c/span\u003e \u003cspan class=\"nv\"\u003ecores\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e\n \u003cspan class=\"nf\"\u003eldmatrix.sync.aligned.m8n8.x4.shared.b16\u003c/span\u003e \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4645\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan 
class=\"nv\"\u003er4646\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4647\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4648\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2789\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003eldmatrix.sync.aligned.m8n8.x4.shared.b16\u003c/span\u003e \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4649\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4650\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4651\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4652\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2793\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003eldmatrix.sync.aligned.m8n8.x4.shared.b16\u003c/span\u003e \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4653\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan 
class=\"nv\"\u003er4654\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4655\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4656\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2797\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003eldmatrix.sync.aligned.m8n8.x4.shared.b16\u003c/span\u003e \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4657\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4658\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4659\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4660\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2801\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e512\u003c/span\u003e \u003cspan class=\"nv\"\u003eldmatrix\u003c/span\u003e \u003cspan class=\"nv\"\u003einstructions\u003c/span\u003e \u003cspan 
class=\"nv\"\u003etotal\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e \u003cspan class=\"nf\"\u003eTensor\u003c/span\u003e \u003cspan class=\"nv\"\u003eCore\u003c/span\u003e \u003cspan class=\"nv\"\u003eMMA\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003eHMMA\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nl\"\u003eNote:\u003c/span\u003e \u003cspan class=\"nf\"\u003esm_120a\u003c/span\u003e \u003cspan class=\"nv\"\u003ehas\u003c/span\u003e \u003cspan class=\"nv\"\u003eno\u003c/span\u003e \u003cspan class=\"nv\"\u003ewgmma\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"nv\"\u003etcgen05\u003c/span\u003e \u003cspan class=\"nv\"\u003einstructions\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003etensor\u003c/span\u003e \u003cspan class=\"nv\"\u003ecore\u003c/span\u003e \u003cspan class=\"nv\"\u003emath\u003c/span\u003e \u003cspan class=\"nv\"\u003euses\u003c/span\u003e \u003cspan class=\"nv\"\u003ewarp\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"nv\"\u003elevel\u003c/span\u003e \u003cspan class=\"nv\"\u003emma.sync\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003eHMMA\u003c/span\u003e \u003cspan class=\"nv\"\u003ein\u003c/span\u003e \u003cspan class=\"nv\"\u003eSASS\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"nf\"\u003emma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\u003c/span\u003e\n \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nf\"\u003ef1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef3\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef4\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nf\"\u003er4645\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4646\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4647\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4648\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nf\"\u003er4709\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er4710\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"err\"\u003e{\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nf\"\u003ef1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef2\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef3\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ef4\u003c/span\u003e\u003cspan class=\"err\"\u003e}\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan 
class=\"nf\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e512\u003c/span\u003e \u003cspan class=\"nv\"\u003emma.sync\u003c/span\u003e \u003cspan class=\"nv\"\u003einstructions\u003c/span\u003e \u003cspan class=\"nv\"\u003etotal\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e \u003cspan class=\"nf\"\u003eAsync\u003c/span\u003e \u003cspan class=\"nv\"\u003eCopy\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nv\"\u003ecp.async\u003c/span\u003e \u003cspan class=\"nv\"\u003efor\u003c/span\u003e \u003cspan class=\"nv\"\u003eglobal\u003c/span\u003e\u003cspan class=\"err\"\u003e→\u003c/span\u003e\u003cspan class=\"nv\"\u003eshared\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e\n \u003cspan class=\"nf\"\u003ecp.async.cg.shared.global\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2856\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003erd112\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"mi\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ep116\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003ecp.async.cg.shared.global\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2857\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan 
class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003erd113\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan class=\"mi\"\u003e16\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ep116\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"nf\"\u003e...\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e136\u003c/span\u003e \u003cspan class=\"nv\"\u003ecp.async\u003c/span\u003e \u003cspan class=\"nv\"\u003einstructions\u003c/span\u003e \u003cspan class=\"nv\"\u003etotal\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n \u003cspan class=\"o\"\u003e//\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e \u003cspan class=\"nf\"\u003eBarrier\u003c/span\u003e \u003cspan class=\"nv\"\u003eSynchronization\u003c/span\u003e \u003cspan class=\"err\"\u003e===\u003c/span\u003e\n \u003cspan class=\"nf\"\u003embarrier.arrive.shared.b64\u003c/span\u003e \u003cspan class=\"nv\"\u003e_\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eglobal_smem\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mi\"\u003e82000\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n \u003cspan class=\"nf\"\u003embarrier.try_wait.parity.shared.b64\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003ep117\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"nv\"\u003eglobal_smem\u003c/span\u003e\u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mi\"\u003e82000\u003c/span\u003e\u003cspan class=\"p\"\u003e],\u003c/span\u003e \u003cspan 
class=\"o\"\u003e%\u003c/span\u003e\u003cspan class=\"nv\"\u003er2371\u003c/span\u003e\u003cspan class=\"c1\"\u003e;\u003c/span\u003e\n\u003cspan class=\"err\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch1 id=\"citation\"\u003eCitation\u003c/h1\u003e\n\n\u003cp\u003eTo cite this article:\u003c/p\u003e\n\n\u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e@article{zhu2026tileir,\n title = {NVIDIA TileIR Internals: from CuTile to MLIR/LLVM to SASS},\n author = {Zhu, Henry},\n journal = {maknee.github.io},\n year = {2026},\n month = {January},\n url = \"https://maknee.github.io/blog/2026/NVIDIA-TileIR-Internals-from-CuTile-to-MLIR-LLVM-to-SASS/\"\n}\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e","summary":"In this post, we’ll dig deep into how TileIR works, from how it generates instructions to analyzing its different passes. We’ll trace how a Mixture-of-Experts (MoE) kernel written in CuTile gets compiled down through cuda_tile → nv_tileaa → nv_tileas → NVVM → LLVM → SASS.","date_published":"2026-01-30T06:00:00+00:00","date_modified":"2026-01-30T06:00:00+00:00","author":{"name":""},"tags":["TileIR"]},{"id":"https://maknee.github.io/blog/2026/Performance-Hints","url":"https://maknee.github.io/blog/2026/Performance-Hints/","title":"Performance Hints","content_html":"\u003c!-- \n\n--\u003e\n\n\u003cp\u003eThis post goes through \u003ca href=\"https://abseil.io/fast/hints.html#performance-hints\"\u003ehttps://abseil.io/fast/hints.html#performance-hints\u003c/a\u003e, a post written by the power duo Jeff Dean and Sanjay Ghemawat, who arguably made Google what it is today. It distills knowledge from both of them, with many examples from Google’s internal codebase. 
Hopefully I can learn a thing or two from professionals who have worked in the industry longer than I have been alive.\u003c/p\u003e\n\n\u003ch1 id=\"reflection-after-reading-this-post\"\u003eReflection after reading this post\u003c/h1\u003e\n\n\u003cp\u003eStart at \u003ca href=\"#performance-hints\"\u003ePerformance Hints\u003c/a\u003e to follow along as I go through the post. This short section contains my takeaways from reading it.\u003c/p\u003e\n\n\u003cp\u003eTL;DR: the post argues for building an intuition for performance and illustrates it with many outcomes drawn from snippets of experience.\u003c/p\u003e\n\n\u003cp\u003eI think the intro was very well written and puts some key points about thinking about performance into perspective.\u003c/p\u003e\n\n\u003cp\u003eThe early sections, especially “The importance of thinking about performance” and “Estimation”, provide a small window into how to think about performance as a sort of lifestyle choice (i.e., a habit of incorporating performance before and during a project rather than after). The motivations for thinking this way vary, but the authors argue that down the line you face consequences, or even bigger time sinks, that could have been avoided in the first place (issues that are harder to spot due to complexity, time spent communicating with people complaining about what you wrote, the difficulty of changing an existing library for performance gains, expensive band-aids for performance issues).\u003c/p\u003e\n\n\u003cp\u003eEstimation has been and always will be important. It’s one way to judge whether your intuition is right (guess, run the experiment, ask “am I wrong?”). And, for me, it’s most likely wrong. One tricky thing to spot is when something sounds right but is wrong. 
Another habit that is hard to build is the “am I wrong” part: I get lost in the sauce of doing something, say “I’m done, ok yeah, let’s move on to the next thing”, and never ask “was I wrong initially?” to see where my estimates went wrong, which is what trickles down into actually doing the thing properly. I think this applies to almost anything, but I haven’t written down and measured estimates outside of the work I do.\u003c/p\u003e\n\n\u003cp\u003eDetailed example sections that were new to me and seem useful: “What to do when profiles are flat”, “Code size considerations”, “Parallelization and synchronization”, and “CLs that demonstrate multiple techniques”.\u003c/p\u003e\n\n\u003ch2 id=\"side-notes-my-thoughts\"\u003eSide notes (my thoughts)\u003c/h2\u003e\n\n\u003cp\u003eOne thing I now especially think about is the cost associated with performance. People typically talk about running services at scale and how many machines are needed for system X to run properly, but I believe it is just as important to look at a single node and its resources, since those resources get replicated and scaled too. Core counts are now 64, 128, or 256, and don’t even get me started on GPU cores. How many GB/s of memory/disk bandwidth are available within a node? Any improvement in compute or transfer on a single node carries over, at least partly, to a cloud-native setting, and a single node is most likely easier to profile and debug.\u003c/p\u003e\n\n\u003cp\u003eSo, ironically, chips have gotten faster and faster and resources (memory, disk) cheaper, and yet we still care about performance? Is it cost? Usability? Or do we face new applications that require more performance?\u003c/p\u003e\n\n\u003cp\u003eWhat about power? Power seems to be becoming more and more of a concern with AI in the GPU/hardware space, and it could even result in \u003ca href=\"https://modal.com/blog/gpu-health\"\u003eerrors on the chip\u003c/a\u003e. Or was it already a concern? 
I mean, the main costs after building the chips, racks, and datacenters are power and maintenance. It seems like the only way performance work can affect power consumption is inadvertently, by eliminating work or doing less of it (basically improving the algorithm). And sometimes performance gains can increase the work done (using more nodes could lower latency). So the question I’m getting at is how to lower power while keeping throughput or latency steady (something like undervolting in the gaming space, where users tweak their hot GPUs to run at much lower power while keeping 95%+ of the performance).\u003c/p\u003e\n\n\u003ch1 id=\"performance-hints\"\u003ePerformance Hints\u003c/h1\u003e\n\n\u003ch2 id=\"the-importance-of-thinking-about-performance\"\u003eThe importance of thinking about performance\u003c/h2\u003e\n\n\u003cp\u003eThis section is the introduction. The authors pack in very insightful yet succinct sentences that made me ponder a lot.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eKnuth is often quoted out of context as saying premature optimization is the root of all evil. The full quote reads: “We should forget about small efficiencies, say about 97% of the time: premature optimization is the root of all evil. Yet we should not pass up our opportunities in that critical 3%.” This document is about that critical 3%, and a more compelling quote, again from Knuth, reads:\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIf you go to the \u003ca href=\"https://dl.acm.org/doi/pdf/10.1145/356635.356640\"\u003elink\u003c/a\u003e, Knuth actually elaborates further on this: “… pass up our opportunities in that critical 3 %. A good programmer will not be lulled into complacency by such reasoning, he will be wise to look carefully at the critical code; but only after that code has been identified. 
It is often a mistake to make a priori judgments about what parts of a program are really critical, since the universal experience of programmers who have been using measurement tools has been that their intuitive guesses fail. After working with such tools for seven years, I’ve become convinced that all compilers written from now on should be designed to provide all programmers with feedback indicating what parts of their programs are costing the most; indeed, this feedback should be supplied automatically unless it has been specifically turned off.”\u003c/p\u003e\n\n\u003cp\u003eThis was published on December 1, 1974, and yet it still hasn’t been solved. What makes me believe AI can solve this when it has stayed unsolved for 50+ years? What makes this such a hard problem?\u003c/p\u003e\n\n\u003cp\u003eIs it “just” telling someone “hey buddy, the code written/generated here is slow” and then telling the AI to fix it? And what makes me believe that AI will find that 3%? Or maybe that 3% doesn’t actually matter for most people, and only matters for critical pieces of code like Postgres or MongoDB? Or, flipping it around, maybe the 3% matters a lot because it’s used by 99% of people (see the xkcd below):\u003c/p\u003e\n\n\u003cp\u003e\u003cimg src=\"https://imgs.xkcd.com/comics/dependency_2x.png\" alt=\"xkcd 2347: Dependency\" /\u003e\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMany people will say “let’s write down the code in as simple a way as possible and deal with performance later when we can profile”. However, this approach is often wrong:\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhen I first read this, I thought: spot on. But actually doing this is difficult. How can you think about performance ahead of time? First, does performance matter? Always yes! It matters to some degree, whether for cost, usability, etc. Second, what type of performance is needed, and to what degree? Are we focused on latency for usability? 
Reliability? Cost? Adaptability? And for each, how much time do we pour in, and what is a good target number to reach? These are difficult questions to answer ahead of time without many years of experience, not just going deep on one project, but on many projects, each with different goals and purposes.\u003c/p\u003e\n\n\u003cp\u003eThe baseline for knowing this is napkin math, which I still need to work on and integrate into the programs I’m working on. I believe this is true for things outside of computer science too. If you’re putting money into, say, a stock, \u003cem\u003eideally\u003c/em\u003e you have some idea of what’s going to happen and can make an educated guess. Or if you’re traveling with multiple people, you need to plan; I don’t think YOLOing the trip will make most people happy.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you disregard all performance concerns when developing a large system, you will end up with a flat profile where there are no obvious hotspots because performance is lost all over the place. It will be difficult to figure out how to get started on performance improvements.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is very true. One thing touches another, which touches another, and the slowdown propagates. Let’s say the problem is the \u003ca href=\"https://youtu.be/IxkSlnrRFqc?t=1483\"\u003eTCP window size\u003c/a\u003e.\u003c/p\u003e\n\n\u003cp\u003eFor example, say you’re serving a GET request in Node.js for a website and, wow, it’s taking 1-2 s from US East to US West. You start adding print lines to the code to get timing measurements. Hmm, it seems like this fetch from the DB is taking a while: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eawait db.query(...)\u003c/code\u003e. Maybe it’s the DB. You change the query to something simple, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eawait db.query(SELECT ... 
COUNT 1)\u003c/code\u003e and, oh, it’s better. So you optimize that query and, bam, queries take ~500 ms, which is somewhat reasonable.\u003c/p\u003e\n\n\u003cp\u003eBut maybe you dig a little differently (not necessarily deeper). You return a dummy result instead of the DB query. Oh? It’s faster? Hmm. By a stroke of luck while messing around, you try a big dummy response and see that it takes 1-2 s. What’s happening? You ask AI, etc., and maybe you land on the TCP window size: only ~15 KB can be sent in the first round trip, while the data you’re sending is 1-2 MB, so it takes several round trips for the window to grow. So you have to somehow compress your data (hopefully it compresses well) or return less data.\u003c/p\u003e\n\n\u003cp\u003eSimilar to the gym, changing too many variables at once (like workouts) makes it difficult to pinpoint what’s going on.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you are developing a library that will be used by other people, the people who will run into performance problems will be likely to be people who cannot easily make performance improvements (they will have to understand the details of code written by other people/teams, and have to negotiate with them about the importance of performance optimizations).\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is another part I’m less experienced with. I think one can gain experience with this by working in open source or big tech, where people care about (or have an incentive to improve) a project. Why can’t others easily make performance improvements to someone else’s library? Maybe many people don’t have the time or reason to look deeper, since doing so usually doesn’t give an obvious, big net benefit (if it gives a net benefit at all!).\u003c/p\u003e\n\n\u003cp\u003eI guess the question is: how can you make it usable? One obvious answer is feedback. But how do you get effective feedback? 
Is it just talking to people who complain that it doesn’t work and trying to decipher what they mean?\u003c/p\u003e\n\n\u003cp\u003eA business faces the same issue with its customers. People don’t care about what goes on inside the product. They want it to work for their specific use case because that’s cheaper and faster than doing it themselves.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIt is harder to make significant changes to a system when it is in heavy use.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eAnother part that I’m not familiar with. Again, big tech or open source is clearly where one can see this. One thing that sticks out is that you have to accommodate existing users and \u003cem\u003etry\u003c/em\u003e to convince them to switch. An example of this is the Python 2 to Python 3 switch. I was kind of mad that you needed to write \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eprint(...)\u003c/code\u003e instead of \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eprint ...\u003c/code\u003e, because you had to type the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e(\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e)\u003c/code\u003e parens, which are harder to reach physically (shift + 9) than the space bar.\u003c/p\u003e\n\n\u003cp\u003eAnd yet I think most things will have to change at some point. 
Few things in life never change.\u003c/p\u003e\n\n\u003cp\u003eFor example, friendships that last many years typically change one way or another.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIt is also hard to tell if there are performance problems that can be solved easily and so we end up with potentially expensive solutions like over-replication or severe overprovisioning of a service to handle load problems.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eAnother area I’m not an expert in. One can guess and estimate issues, but honestly, it’s fucking hard. Real applications often see usage spikes at certain times, and the \u003cem\u003elet’s solve this for now by X and vibe it with things I know\u003c/em\u003e approach can end up as patches that don’t solve the actual issue, so maybe you spend more time or more money than necessary. But deciding whether to spend that time now or later is so difficult.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eInstead, we suggest that when writing code, try to choose the faster alternative if it does not impact readability/complexity of the code significantly.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eNot sure what to expect, but I will revisit these 4 key points when I’m done going through the rest.\u003c/p\u003e\n\n\u003ch2 id=\"estimation\"\u003eEstimation\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you can develop an intuition for how much performance might matter in the code you are writing, you can make a more informed decision (e.g., how much extra complexity is warranted in the name of performance).\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh man, the word intuition. Ugh, it’s the best word for what it describes, but how people learn and build an intuition varies from person to person.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIs it test code? 
If so, you need to worry mostly about the asymptotic complexity of your algorithms and data structures. (Aside: development cycle time matters, so avoid writing tests that take a long time to run.)\u003c/p\u003e\n \u003cp\u003eIs it code specific to an application? If so, try to figure out how much performance matters for this piece of code. This is typically not very hard: just figuring out whether code is initialization/setup code vs. code that will end up on hot paths (e.g., processing every request in a service) is often sufficient\u003c/p\u003e\n \u003cp\u003eIs it library code that will be used by many applications? In this case it is hard to tell how sensitive it might become. This is where it becomes especially important to follow some of the simple techniques described in this document. For example, if you need to store a vector that usually has a small number of elements, use an absl::InlinedVector instead of std::vector. Such techniques are not very hard to follow and don’t add any non-local complexity to the system. And if it turns out that the code you are writing does end up using significant resources, it will be higher performance from the start. And it will be easier to find the next thing to focus on when looking at a profile.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSo my understanding is: think about what kind of work the code does for the application you’re building, and follow good general rules throughout the project, like drinking X amount of water per day (drinking more is generally good for you).\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eYou can do a slightly deeper analysis when picking between options with potentially different performance characteristics by relying on back of the envelope calculations. 
Such calculations can quickly give a very rough estimate of the performance of different alternatives, and the results can be used to discard some of the alternatives without having to implement them.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThey finally mentioned it. Ok, let’s see what has changed in the ~20 years since Jeff first mentioned this.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eHere is how such an estimation might work:\u003c/p\u003e\n \u003cp\u003eEstimate how many low-level operations of various kinds are required, e.g., number of disk seeks, number of network round-trips, bytes transmitted etc.\u003c/p\u003e\n \u003cp\u003eMultiply each kind of expensive operation with its rough cost, and add the results together.\u003c/p\u003e\n \u003cp\u003eThe preceding gives the cost of the system in terms of resource usage. If you are interested in latency, and if the system has any concurrency, some of the costs may overlap and you may have to do slightly more complicated analysis to estimate the latency.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIn other words: count every transfer or movement of data, multiply each by its cost (time or $), and add everything up to get the total estimate. 
The following table is the one everyone has seen.\u003c/p\u003e\n\n\u003cdiv class=\"language-md highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003eL1 cache reference                         0.5 ns\nL2 cache reference                           3 ns\nBranch mispredict                            5 ns\nMutex lock/unlock (uncontended)             15 ns\nMain memory reference                       50 ns\nCompress 1K bytes with Snappy            1,000 ns\nRead 4KB from SSD                       20,000 ns\nRound trip within same datacenter       50,000 ns\nRead 1MB sequentially from memory       64,000 ns\nRead 1MB over 100 Gbps network         100,000 ns\nRead 1MB from SSD                    1,000,000 ns\nDisk seek                            5,000,000 ns\nRead 1MB sequentially from disk     10,000,000 ns\nSend packet CA-\u0026gt;Netherlands-\u0026gt;CA    150,000,000 ns\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eYou may find it useful to also track estimated costs for higher-level operations relevant to your system. E.g., you might want to know the rough cost of a point read from your SQL database, the latency of interacting with a Cloud service, or the time to render a simple HTML page. If you don’t know the relevant cost of different operations, you can’t do decent back-of-the-envelope calculations!\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYeah, I understand this one a bit better. 
It’s incredibly hard to keep track of initially, because it’s hard to know what’s important and I haven’t used it consistently (daily/weekly, etc.).\u003c/p\u003e\n\n\u003ch3 id=\"example-time-to-quicksort-a-billion-4-byte-numbers\"\u003eExample: Time to quicksort a billion 4 byte numbers\u003c/h3\u003e\n\n\u003cp\u003eBefore looking at the answer, I would like to ask myself where this would be used, which components are mainly involved, and what dominates the cost (the bottleneck).\u003c/p\u003e\n\n\u003cp\u003eMaybe we have many durations (say, from multiple services) and would like to plot a latency histogram for a web UI query.\u003c/p\u003e\n\n\u003cp\u003eComponents: memory, CPU. 1B * 4 bytes = 4GB of data, which is kinda tiny by today’s standards (one machine can handle it).\u003c/p\u003e\n\n\u003cp\u003eSo let’s say it’s already in memory, not on disk. The quickest case is if the data is already sorted and we’re only reading each element and writing it back to another piece of memory. So 50ns * 1B * 2 (one for the read and one for the write) = 50s * 2 = 100s?\u003c/p\u003e\n\n\u003cp\u003eTo be honest, it’s probably more. Say all of the elements are unsorted: we would have to move all of them and then repeat on each subset, like merge sort, and say without threading. So it would be like an infinite geometric series, 100s + 50s + 25s…, which (I had to search this up) converges to s = a / (1 - r) = 100s / (1 - 0.5) = 200s.\u003c/p\u003e\n\n\u003cp\u003eSo between 100s and 200s would be my answer.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMemory bandwidth: the array occupies 4 GB (4 bytes per number times a billion numbers). Let’s assume ~16GB/s of memory bandwidth per core. That means each pass will take ~0.25s. N is ~2^30, so we will make ~30 passes, so the total cost of memory transfer will be ~7.5 seconds.\u003c/p\u003e\n \u003cp\u003eBranch mispredictions: we will do a total of N*log(N) comparisons, i.e., ~30 billion comparisons. 
Let’s assume that half of them (i.e., 15 billion) are mispredicted. Multiplying by 5 ns per misprediction, we get a misprediction cost of 75 seconds. We assume for this analysis that correctly predicted branches are free.\u003c/p\u003e\n \u003cp\u003eAdding up the previous numbers, we get an estimate of ~82.5 seconds.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eMy answer was way off. Let me actually look at the table and try to do it in their style: 1. figure out the algorithm and the operation counts, and 2. find the components involved.\u003c/p\u003e\n\n\u003cp\u003eOk, memory bandwidth comes from counting the passes: 4GB / 16GB/s (time per pass) * log2(1B) (number of passes) = 0.25s * ~30 = 7.5s.\u003c/p\u003e\n\n\u003cp\u003eNext is computation, i.e., branch mispredictions: N log(N) comparisons = 1B * log2(1B) = ~30B comparisons, and if half are mispredicted, that’s 15B. Then 15B * 5ns = 75s, which is surprising! I didn’t expect it to be compute bound.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLet’s assume we have a 32MB L3 cache, and that the cost of transferring data from L3 cache to the processor is negligible. The L3 cache can hold 2^23 numbers, and therefore the last 22 passes can operate on the data resident in the L3 cache (the 23rd last pass brings data into the L3 cache and the remaining passes operate on that data.) That cuts down the memory transfer cost to 2.5 seconds (10 memory transfers of 4GB at 16GB/s) instead of 7.5 seconds (30 memory transfers).\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWow… ok, they talk about the caches here. 2^23 comes from 2^20 (1MB) * 2^5 (32) / 2^2 (4 bytes per number). 
So once a partition shrinks to 2^23 numbers, it fits entirely in L3: the pass that first touches it brings it into the cache, and the last 22 passes run on cached data.\u003c/p\u003e\n\n\u003ch3 id=\"example-time-to-generate-a-web-page-with-30-image-thumbnails\"\u003eExample: Time to generate a web page with 30 image thumbnails\u003c/h3\u003e\n\n\u003cp\u003eLet’s compare two potential designs where the original images are stored on disk, and each image is approximately 1MB in size.\u003c/p\u003e\n\n\u003cp\u003eTwo main components: loading from disk and transferring data over the web. Total data: 30MB.\u003c/p\u003e\n\n\u003cp\u003eDisk: 30MB / 1GB/s = 0.03s. Web: if the window doubles every round trip starting from one 1.5KB packet, that’s log2(30MB / 1.5KB) = log2(20,000) ≈ 14 round trips * 150ms per round trip ≈ 2.1s.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eRead the contents of the 30 images serially and generate a thumbnail for each one. Each read takes one seek + one transfer, which adds up to 5ms for the seek, and 10ms for the transfer, which adds up to 30 images times 15ms per image, i.e., 450ms.\u003c/p\u003e\n \u003cp\u003eRead in parallel, assuming the images are spread evenly across K disks. The previous resource usage estimate still holds, but latency will drop by roughly a factor of K, ignoring variance (e.g, we will sometimes get unlucky and one disk will have more than 1/Kth of the images we are reading). Therefore if we are running on a distributed filesystem with hundreds of disks, the expected latency will drop to ~15ms.\u003c/p\u003e\n \u003cp\u003eLet’s consider a variant where all images are on a single SSD. This changes the sequential read performance to 20µs + 1ms per image, which adds up to ~30 ms overall.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eCool, my disk estimate matches their single-SSD number (~30ms). But I guess realistically it would be the HDD version (say, S3).\u003c/p\u003e\n\n\u003ch1 id=\"measurement\"\u003eMeasurement\u003c/h1\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe preceding section gives some tips about how to think about performance when writing code without worrying too much about how to measure the performance impact of your choices. 
However, before you actually start making improvements, or run into a tradeoff involving various things like performance, simplicity, etc. you will want to measure or estimate potential performance benefits. Being able to measure things effectively is the number one tool you’ll want to have in your arsenal when doing performance-related work.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI should really keep this in mind: estimate before actually running it. The question is whether that’s even feasible.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAs an aside, it’s worth pointing out that profiling code that you’re unfamiliar with can also be a good way of getting a general sense of the structure of the codebase and how it operates. Examining the source code of heavily involved routines in the dynamic call graph of a program can give you a high level sense of “what happens” when running the code, which can then build your own confidence in making performance-improving changes in slightly unfamiliar code.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes! I think just reading code gives one an idealized view of it. So much complexity happens behind the scenes. Is there lock contention? False sharing? Too much time spent allocating? Memory leaks? These are things to observe that aren’t easily picked up by reading the code or stuffing it into an LLM.\u003c/p\u003e\n\n\u003ch2 id=\"profiling-tools-and-tips\"\u003eProfiling tools and tips\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you can, write a microbenchmark that covers the code you are improving. Microbenchmarks improve turnaround time when making performance improvements, help verify the impact of performance improvements, and can help prevent future performance regressions. However microbenchmarks can have pitfalls that make them non-representative of full system performance.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eVery true. 
It helps build an understanding of the system’s individual components, which shows where the full system adds overhead.\u003c/p\u003e\n\n\u003ch2 id=\"what-to-do-when-profiles-are-flat\"\u003eWhat to do when profiles are flat\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eFind loops closer to the top of call stacks (flame graph view of a CPU profile can be helpful here). Potentially, the loop or the code it calls could be restructured to be more efficient. Some code that initially built a complicated graph structure incrementally by looping over nodes and edges of the input was changed to build the graph structure in one shot by passing it the entire input. This removed a bunch of internal checks that were happening per edge in the initial code.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInstead of looping and calling per element, pass the whole input (e.g., an array) in one shot to avoid per-element overhead like repeated checks.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTake a step back and look for structural changes higher up in the call stacks instead of concentrating on micro-optimizations. The techniques listed under algorithmic improvements can be useful when doing this.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThink at the algorithm level instead of doing micro-optimizations.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLook for overly general code. Replace it with a customized or lower-level implementation. E.g., if an application is repeatedly using a regular expression match where a simple prefix match would suffice, consider dropping the use of the regular expression.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eMakes sense, though it seems like a micro-optimization, so I would be wary of this.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAttempt to reduce the number of allocations: get an allocation profile, and pick away at the highest contributor to the number of allocations. 
This will have two effects: (1) It will provide a direct reduction of the amount of time spent in the allocator (and garbage collector for GC-ed languages) (2) There will often be a reduction in cache misses since in a long running program using tcmalloc, every allocation tends to go to a different cache line.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI’ve seen this happen SO many times. Allocations eat up so many cycles, and it’s actually frustrating to fix.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eGather other types of profiles, specially ones based on hardware performance counters. Such profiles may point out functions that are encountering a high cache miss rate. Techniques described in the profiling tools and tips section can be helpful.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes, but one needs to learn how to use these performance counters at a system level, and they are typically sampled, which makes it hard to pinpoint the exact cause. I guess perf would help here with something like cache misses.\u003c/p\u003e\n\n\u003ch2 id=\"api-considerations\"\u003eAPI considerations\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWidely used APIs come under heavy pressure to add features. Be careful when adding new features since these will constrain future implementations and increase cost unnecessarily for users who don’t need the new features. E.g., many C++ standard library containers promise iterator stability, which in typical implementations increases the number of allocations significantly, even though many users do not need pointer stability.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eKeep the API as minimal as possible, kind of like C, I guess? 
But make the interface actually good.\u003c/p\u003e\n\n\u003ch3 id=\"bulk-apis\"\u003eBulk APIs\u003c/h3\u003e\n\n\u003cp\u003eOffer batch operations so that fixed per-call costs are paid once per batch: here \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eDeleteRefs\u003c/code\u003e acquires the mutex once for the whole batch instead of once per call to \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eDeleteRef\u003c/code\u003e.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: only the single-element DeleteRef.\u003c/span\u003e\n\u003cspan class=\"k\"\u003etemplate\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003etypename\u003c/span\u003e \u003cspan class=\"nc\"\u003eT\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eObjectStore\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003epublic:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStatus\u003c/span\u003e \u003cspan class=\"n\"\u003eDeleteRef\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eRef\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// After: keep DeleteRef and add a bulk DeleteRefs.\u003c/span\u003e\n\u003cspan class=\"k\"\u003etemplate\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003etypename\u003c/span\u003e \u003cspan class=\"nc\"\u003eT\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eObjectStore\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003epublic:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStatus\u003c/span\u003e \u003cspan class=\"n\"\u003eDeleteRef\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003eRef\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Delete many references. For each ref, if no other Refs point to the same\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// object, the object will be deleted. Returns non-OK on any error.\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStatus\u003c/span\u003e \u003cspan class=\"n\"\u003eDeleteRefs\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eSpan\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eRef\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003erefs\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"k\"\u003etemplate\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003etypename\u003c/span\u003e \u003cspan class=\"nc\"\u003eT\u003c/span\u003e\u003cspan class=\"p\"\u003e\u0026gt;\u003c/span\u003e\n\u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStatus\u003c/span\u003e \u003cspan class=\"n\"\u003eObjectStore\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eT\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;::\u003c/span\u003e\u003cspan class=\"n\"\u003eDeleteRefs\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eSpan\u003c/span\u003e\u003cspan 
class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eRef\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003erefs\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eutil\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eStatus\u003c/span\u003e \u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003emu_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eref\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003erefs\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eUpdate\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eDeleteRefLocked\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eref\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan 
class=\"p\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"view-types\"\u003eView types\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThese types reduce copying, and allow callers to pick their own container types (e.g., one caller might use std::vector whereas another one uses absl::InlinedVector).\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYep! I’ve been using these.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eFor frequently called routines, sometimes it is useful to allow higher-level callers to pass in a data structure that they own or information that the called routine needs that the client already has. This can avoid the low-level routine being forced to allocate its own temporary data structure or recompute already-available information.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: the caller already has now, but RecordRPC does not receive it.\u003c/span\u003e\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eWallTime\u003c/span\u003e \u003cspan class=\"n\"\u003enow\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eWallTime_Now\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"n\"\u003eRPC_Stats\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRecordRPC\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003estats_name\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003em\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// After: pass now in so RecordRPC can reuse it.\u003c/span\u003e\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eWallTime\u003c/span\u003e \u003cspan class=\"n\"\u003enow\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan 
class=\"n\"\u003eWallTime_Now\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"n\"\u003eRPC_Stats\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRecordRPC\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003estats_name\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003em\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003enow\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eThis makes sense: the caller already has \u003ccode class=\"language-plaintext highlighter-rouge\"\u003enow\u003c/code\u003e, so passing it in saves \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eRecordRPC\u003c/code\u003e from fetching the time again.\u003c/p\u003e\n\n\u003ch3 id=\"thread-compatible-vs-thread-safe-types\"\u003eThread-compatible vs. Thread-safe types\u003c/h3\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: thread-safe, takes a lock on every call.\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTransferPhase\u003c/span\u003e \u003cspan class=\"n\"\u003eHitlessTransferPhase\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eget\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"n\"\u003eCallsiteMetrics\u003c/span\u003e \u003cspan class=\"n\"\u003ecm\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"s\"\u003e\"HitlessTransferPhase::get\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMonitoredMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ecm\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003emutex_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ephase_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// After: thread-compatible, so callers handle synchronization.\u003c/span\u003e\n\u003cspan class=\"n\"\u003eTransferPhase\u003c/span\u003e \u003cspan class=\"n\"\u003eHitlessTransferPhase\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eget\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ephase_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eLet the caller handle synchronization. This makes sense for performance, since \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eget()\u003c/code\u003e no longer takes a lock on every call, even when the caller doesn’t need one.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe most critical opportunities for performance improvements come from algorithmic improvements, e.g., turning an O(N²) algorithm to O(N lg(N)) or O(N), avoiding potentially exponential behavior, etc. These opportunities are rare in stable code, but are worth paying attention to when writing new code. A few examples that show such improvements to pre-existing code:\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eRare in stable code! Man, they must have thought through most things already.\u003c/p\u003e\n\n\u003ch2 id=\"better-memory-representation\"\u003eBetter memory representation\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eCareful consideration of memory footprint and cache footprint of important data structures can often yield big savings. 
The data structures below focus on supporting common operations by touching fewer cache lines. Care taken here can (a) avoid expensive cache misses (b) reduce memory bus traffic, which speeds up both the program in question and anything else running on the same machine\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes, these are expensive resources on any machine.\u003c/p\u003e\n\n\u003ch3 id=\"memory-layout\"\u003eMemory layout\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePlace hot read-only fields away from hot mutable fields so that writes to the mutable fields do not cause the read-only fields to be evicted from nearby caches.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh I get it: a write invalidates that cache line in other cores’ caches, so read-only fields sharing the line get evicted along with it.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eConsider packing things into fewer bytes by using bit and byte-level encoding. This can be complicated, so only do this when the data under question is encapsulated inside a well-tested module, and the overall reduction of memory usage is significant. Furthermore, watch out for side effects like under-alignment of frequently used data, or more expensive code for accessing packed representations. Validate such changes using benchmarks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eMakes sense. Trade CPU for space, like varints, etc.\u003c/p\u003e\n\n\u003ch3 id=\"indices-instead-of-pointers\"\u003eIndices instead of pointers\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOn modern 64-bit machines, pointers take up 64 bits. If you have a pointer-rich data structure, you can easily chew up lots of memory with indirections of T*. Instead, consider using integer indices into an array T[] or other data structure. 
Not only will the references be smaller (if the number of indices is small enough to fit in 32 or fewer bits), but the storage for all the T[] elements will be contiguous, often leading to better cache locality.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSmaller references: a 4-byte index can already address about 4 billion (2^32) elements at half the storage cost of a 64-bit pointer.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAvoid data structures that allocate a separate object per stored element (e.g., std::map, std::unordered_map in C++). Instead, consider types that use chunked or flat representations to store multiple elements in close proximity in memory (e.g., std::vector, absl::flat_hash_{map,set} in C++). Such types tend to have much better cache behavior. Furthermore, they encounter less allocator overhead.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes, though mainly in performance-critical code. A flat representation can sometimes be tricky to design, but flat hash maps/sets are nice.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOne useful technique is to partition elements into chunks where each chunk can hold a fixed number of elements. This technique can reduce the cache footprint of a data structure significantly while preserving good asymptotic behavior.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes! This shows up in many implementations, such as high-performance read/write queues.\u003c/p\u003e\n\n\u003ch3 id=\"arenas\"\u003eArenas\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eArenas can help reduce memory allocation cost, but they also have the benefit of packing together independently allocated items next to each other, typically in fewer cache lines, and eliminating most destruction costs. They are likely most effective for complex data structures with many sub-objects. Consider providing an appropriate initial size for the arena since that can help reduce allocations. 
Caveat: it is easy to misuse arenas by putting too many short-lived objects in a long-lived arena, which can unnecessarily bloat memory footprint.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eBasically, you allocate space for items ahead of time, though you may not end up using the entire arena. It’s tricky to get right… very tricky… especially estimating how big it should be.\u003c/p\u003e\n\n\u003ch3 id=\"arrays-instead-of-maps\"\u003eArrays instead of maps\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf the domain of a map can be represented by a small integer or is an enum, or if the map will have very few elements, the map can sometimes be replaced by an array or a vector of some form.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: general-purpose flat map\u003c/span\u003e\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003egtl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eflat_map\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003epayload_type_to_clock_frequency_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: fixed-size array indexed by payload type\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// A map (implemented as a simple array) indexed by payload_type to clock freq\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// for that payload type (or 0)\u003c/span\u003e\n\u003cspan class=\"k\"\u003estruct\u003c/span\u003e \u003cspan class=\"nc\"\u003ePayloadTypeToClockRateMap\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan 
class=\"mi\"\u003e128\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003ePayloadTypeToClockRateMap\u003c/span\u003e \u003cspan class=\"n\"\u003epayload_type_to_clock_frequency_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eThis only works when the key can serve directly as a small array index, such as an enum or a small integer.\u003c/p\u003e\n\n\u003ch3 id=\"bit-vectors-instead-of-sets\"\u003eBit vectors instead of sets\u003c/h3\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: hash set of zone IDs\u003c/span\u003e\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eZoneSet\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"k\"\u003epublic\u003c/span\u003e \u003cspan class=\"n\"\u003edense_hash_set\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eZoneId\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003epublic:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eContains\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eZoneId\u003c/span\u003e \u003cspan class=\"n\"\u003ezone\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ecount\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ezone\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: bit vector indexed by zone ID\u003c/span\u003e\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eZoneSet\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Returns true iff \"zone\" is contained in the set\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eContainsZone\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eZoneId\u003c/span\u003e \u003cspan class=\"n\"\u003ezone\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ezone\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eb_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003eb_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eget_bit\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ezone\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003eprivate\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003esize_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e 
\u003cspan class=\"c1\"\u003e// Number of zones inserted\u003c/span\u003e\n \u003cspan class=\"n\"\u003eutil\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003ebitmap\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eInlinedBitVector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"mi\"\u003e256\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003eb_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eI haven’t actually used this before. It’s essentially a vector of bits, one per possible zone ID, instead of a set of values, so a membership check becomes a single bit test. I don’t use sets that often…\u003c/p\u003e\n\n\u003ch3 id=\"reduce-allocations\"\u003eReduce allocations\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eNewly-allocated objects may require expensive initialization and sometimes corresponding expensive destruction when no longer needed.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI see this time and time again.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eEvery allocation tends to be on a new cache line and therefore data spread across many independent allocations will have a larger cache footprint than data spread across fewer allocations.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes. 
Batch your allocations (basically an arena).\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: allocates a new DeviceInfo whenever dinfo is null\u003c/span\u003e\n\u003cspan class=\"n\"\u003eLiveTensor\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eLiveTensor\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eTensor\u003c/span\u003e \u003cspan class=\"n\"\u003et\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eshared_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003edinfo\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003et\u003c/span\u003e\u003cspan class=\"p\"\u003e)),\u003c/span\u003e\n \u003cspan class=\"n\"\u003edevice_info\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edinfo\u003c/span\u003e \u003cspan class=\"o\"\u003e?\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan 
class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edinfo\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emake_shared\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e()),\u003c/span\u003e\n \u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: share a single empty DeviceInfo instead of allocating one per tensor\u003c/span\u003e\n\u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eshared_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003eempty_device_info\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eshared_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;*\u003c/span\u003e \u003cspan class=\"n\"\u003eresult\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e\n \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan 
class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eshared_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eresult\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eLiveTensor\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eLiveTensor\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etf\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eTensor\u003c/span\u003e \u003cspan class=\"n\"\u003et\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eshared_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eDeviceInfo\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003edinfo\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003etensor\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003et\u003c/span\u003e\u003cspan class=\"p\"\u003e)),\u003c/span\u003e \u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eis_batched\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edinfo\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003edevice_info\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edinfo\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003edevice_info\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eempty_device_info\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"resize-or-reserve-containers\"\u003eResize or reserve containers\u003c/h3\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: grow the vector one push_back at a time\u003c/span\u003e\n\u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003endocs\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003euint32\u003c/span\u003e \u003cspan class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eERRORCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eb\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eGetRice\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003erice_base\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"n\"\u003edocs_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eDocId\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003emy_shard_\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan 
class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003enum_shards_\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003edelta\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003edocs_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003epush_back\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elast_docid_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: size the vector once, then fill it in place\u003c/span\u003e\n\u003cspan class=\"n\"\u003edocs_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eresize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003endocs\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"n\"\u003eDocId\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003edocptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003edocs_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n\u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan 
class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003endocs\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003euint32\u003c/span\u003e \u003cspan class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eERRORCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eb\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eGetRice\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003erice_base\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003edocptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eDocId\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003emy_shard_\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003edelta\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003enum_shards_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n 
\u003cspan class=\"n\"\u003edocptr\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003edelta\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003edocptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003elast_docid_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eI actually do this a lot in my performance-sensitive, preallocated code. Wow, I guess I do some things correctly.\u003c/p\u003e\n\n\u003ch3 id=\"avoid-copying-when-possible\"\u003eAvoid copying when possible\u003c/h3\u003e\n\n\u003cp\u003eOne of the most critical things to do (no work is better than some work).\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before\u003c/span\u003e\n\u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003esearch_iterators\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eDocPLIteratorFactory\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eCreate\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eopts\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// After: move opts instead of copying it\u003c/span\u003e\n\u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003esearch_iterators\u003c/span\u003e\u003cspan 
class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eDocPLIteratorFactory\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eCreate\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003emove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eopts\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: constructs a new T on every iteration\u003c/span\u003e\n\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eiterator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eWrapUnique\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esstable\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eGetIterator\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003edone\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eT\u003c/span\u003e \u003cspan class=\"n\"\u003eprofile\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eprofile\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eParseFromString\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003evalue_view\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eInternalError\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"s\"\u003e\"Failed to parse mem_block to specified profile type.\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eNext\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: one T is reused across iterations\u003c/span\u003e\n\u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003eiterator\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eWrapUnique\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003esstable\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eGetIterator\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n\u003cspan class=\"n\"\u003eT\u003c/span\u003e \u003cspan class=\"n\"\u003eprofile\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan 
class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003edone\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e!\u003c/span\u003e\u003cspan class=\"n\"\u003eprofile\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eParseFromString\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003evalue_view\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eInternalError\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"s\"\u003e\"Failed to parse mem_block to specified profile type.\"\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eiterator\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eNext\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOften, code is written to cover all cases, but some subset of the cases are much simpler and more common than others. 
E.g., vector::push_back usually has enough space for the new element, but contains code to resize the underlying storage when it does not. Some attention paid to the structure of code can help make the common simple case faster without hurting uncommon case performance significantly.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYou have to know which case is common and which is rare behind an API call. For example, if no errors happened, we shouldn’t touch the error counters at all.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Before: always loops over every error counter\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC_Stats_Measurement\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"k\"\u003eoperator\u003c/span\u003e\u003cspan class=\"o\"\u003e+=\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC_Stats_Measurement\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNUM_ERRORS\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan 
class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eerrors\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e+=\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eerrors\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// After: skip the loop entirely when x recorded no errors\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC_Stats_Measurement\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"k\"\u003eoperator\u003c/span\u003e\u003cspan class=\"o\"\u003e+=\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC_Stats_Measurement\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eany_errors_set\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan 
class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNUM_ERRORS\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eerrors\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e+=\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eerrors\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"n\"\u003eany_errors_set\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nb\"\u003etrue\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePreallocate 10 nodes not 200 for query handling in Google’s web server. A simple change that reduced web server’s CPU usage by 7.5%. 
Wow.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003equerytree.h\u003c/code\u003e\u003c/p\u003e\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ekInitParseTreeSize\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e200\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// initial size of querynode pool\u003c/span\u003e\n\u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ekInitParseTreeSize\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e10\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// initial size of querynode pool\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"specialize-code\"\u003eSpecialize code\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eA particular performance-sensitive call-site may not need the full generality provided by a general-purpose library. Consider writing specialized code in such cases instead of calling the general-purpose code if it provides a performance improvement.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting, I haven’t done this before. 
This is best reserved for very heavily used code.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003etype\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eMATCH_TYPE_REGEXP\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eterm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eNonMetaPrefix\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003eCopyToString\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eprefix\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eterm\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eRegexpSuffix\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"s\"\u003e\".*\"\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Special case for a regexp that matches anything, so we can\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// bypass RE2::FullMatch\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003etype\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan 
class=\"n\"\u003eMATCH_TYPE_PREFIX\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e \u003cspan class=\"k\"\u003eelse\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003etype\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eMATCH_TYPE_REGEXP\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"make-the-compilers-job-easier\"\u003eMake the compiler’s job easier\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe application programmer will often know more about the behavior of the system and can aid the compiler by rewriting the code to operate at a lower level. However, only do this when profiles show an issue since compilers will often get things right on their own. Looking at the generated assembly code for performance critical routines can help you understand if the compiler is “getting it right”. Pprof provides a very helpful display of source code interleaved with disassembly and annotated with performance data.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIf you understand the code extremely well, you can get to this stage, OR use specific tool that shows the assembly (rare!)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAvoid functions calls in hot functions (allows the compiler to avoid frame setup costs).\nMove slow-path code into a separate tail-called function.\nCopy small amounts of data into local variables before heavy use. 
This can let the compiler assume there is no aliasing with other data, which may improve auto-vectorization and register allocation.\nHand-unroll very hot loops.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eKey\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eInitSeps\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003estart\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003erep_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003erep_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan 
class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003estart\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eDCHECK_GE\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDCHECK_LT\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003es\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\u003cspan class=\"n\"\u003ememchr\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e \u003cspan class=\"o\"\u003e!=\u003c/span\u003e \u003cspan class=\"nb\"\u003eNULL\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eseps_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"kr\"\u003einline\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"nf\"\u003eScanBackwardsForSep\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan 
class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e 
\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e-=\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u0026amp;\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan 
class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e!=\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e--\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eKey\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eInitSeps\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003estart\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003erep_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003erep_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan 
class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003estart\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eDCHECK_GE\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDCHECK_LT\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// We go backwards from the end of the string, rather than forwards,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// since the directory name might be long and definitely doesn't contain\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// any '#' characters.\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eScanBackwardsForSep\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003elimit\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003eDCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eseps_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e2\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e--\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eScanBackwardsForSep\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eseps_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan 
class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e--\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eScanBackwardsForSep\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eDCHECK\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"sc\"\u003e'#'\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eseps_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"reduce-stats-collection-costs\"\u003eReduce stats collection costs\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBalance the utility of stats and other behavioral information about a system against the cost of maintaining that information. The extra information can often help people to understand and improve high-level behavior, but can also be costly to maintain.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes, I’ve seen this. 
The essential question is: how do you decide what to instrument?\u003c/p\u003e\n\n\u003cp\u003ePart of a set of changes that reduced the time to set an alarm from 771 ns to 271 ns.\u003c/p\u003e\n\n\u003cp\u003e\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eselectserver.h\u003c/code\u003e\u003c/p\u003e\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eSelectServer\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003epublic:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"nl\"\u003eprotected:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003escoped_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eMinuteTenMinuteHourStat\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003enum_alarms_stat_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003escoped_ptr\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eMinuteTenMinuteHourStat\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003enum_closures_stat_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// Selectserver class\u003c/span\u003e\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eSelectServer\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"nl\"\u003eprotected:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"n\"\u003eselectserver\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eSelectServer\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eAddAlarmInternal\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eAlarmer\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ealarmer\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eoffset_in_ms\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eid\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eis_periodic\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003ealarms_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003einsert\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealarm\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003enum_alarms_stat_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eIncBy\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eSelectServer\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eAddAlarmInternal\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eAlarmer\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ealarmer\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eoffset_in_ms\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eid\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eis_periodic\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003ealarms_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eAdd\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealarm\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"n\"\u003eselectserver\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eSelectServer\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRemoveAlarm\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eAlarmer\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ealarmer\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eid\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003ealarms_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eerase\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealarm\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003enum_alarms_stat_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eIncBy\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan 
class=\"n\"\u003eSelectServer\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRemoveAlarm\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eAlarmer\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ealarmer\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eid\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003ealarms_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eRemove\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ealarm\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003eOften\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003estats\u003c/span\u003e \u003cspan class=\"n\"\u003eor\u003c/span\u003e \u003cspan class=\"n\"\u003eother\u003c/span\u003e \u003cspan class=\"n\"\u003eproperties\u003c/span\u003e \u003cspan class=\"n\"\u003ecan\u003c/span\u003e \u003cspan class=\"n\"\u003ebe\u003c/span\u003e \u003cspan class=\"n\"\u003emaintained\u003c/span\u003e \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003esample\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003eelements\u003c/span\u003e \u003cspan class=\"n\"\u003ehandled\u003c/span\u003e \u003cspan class=\"n\"\u003eby\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003esystem\u003c/span\u003e 
\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ee\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eg\u003c/span\u003e\u003cspan class=\"p\"\u003e.,\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC\u003c/span\u003e \u003cspan class=\"n\"\u003erequests\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003einput\u003c/span\u003e \u003cspan class=\"n\"\u003erecords\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eusers\u003c/span\u003e\u003cspan class=\"p\"\u003e).\u003c/span\u003e \u003cspan class=\"n\"\u003eMany\u003c/span\u003e \u003cspan class=\"n\"\u003esubsystems\u003c/span\u003e \u003cspan class=\"n\"\u003euse\u003c/span\u003e \u003cspan class=\"k\"\u003ethis\u003c/span\u003e \u003cspan class=\"n\"\u003eapproach\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etcmalloc\u003c/span\u003e \u003cspan class=\"n\"\u003eallocation\u003c/span\u003e \u003cspan class=\"n\"\u003etracking\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"n\"\u003erequestz\u003c/span\u003e \u003cspan class=\"n\"\u003estatus\u003c/span\u003e \u003cspan class=\"n\"\u003epages\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eDapper\u003c/span\u003e \u003cspan class=\"n\"\u003esamples\u003c/span\u003e\u003cspan class=\"p\"\u003e).\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eWhen\u003c/span\u003e \u003cspan class=\"n\"\u003esampling\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003econsider\u003c/span\u003e \u003cspan class=\"n\"\u003ereducing\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003esampling\u003c/span\u003e \u003cspan class=\"n\"\u003erate\u003c/span\u003e \u003cspan 
class=\"n\"\u003ewhen\u003c/span\u003e \u003cspan class=\"n\"\u003eappropriate\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eThis\u003c/span\u003e \u003cspan class=\"n\"\u003echange\u003c/span\u003e \u003cspan class=\"n\"\u003ereduces\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003esampling\u003c/span\u003e \u003cspan class=\"n\"\u003erate\u003c/span\u003e \u003cspan class=\"n\"\u003efrom\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"mi\"\u003e10\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"mf\"\u003e32.\u003c/span\u003e \u003cspan class=\"n\"\u003eFurthermore\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ewe\u003c/span\u003e \u003cspan class=\"n\"\u003enow\u003c/span\u003e \u003cspan class=\"n\"\u003ekeep\u003c/span\u003e \u003cspan class=\"n\"\u003eexecution\u003c/span\u003e \u003cspan class=\"n\"\u003etime\u003c/span\u003e \u003cspan class=\"n\"\u003estats\u003c/span\u003e \u003cspan class=\"n\"\u003ejust\u003c/span\u003e \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003esampled\u003c/span\u003e \u003cspan class=\"n\"\u003eevents\u003c/span\u003e \u003cspan class=\"n\"\u003eand\u003c/span\u003e \u003cspan class=\"n\"\u003espeed\u003c/span\u003e \u003cspan class=\"n\"\u003eup\u003c/span\u003e \u003cspan class=\"n\"\u003esampling\u003c/span\u003e \u003cspan class=\"n\"\u003edecisions\u003c/span\u003e \u003cspan 
class=\"n\"\u003eby\u003c/span\u003e \u003cspan class=\"k\"\u003eusing\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003epower\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003etwo\u003c/span\u003e \u003cspan class=\"n\"\u003emodulus\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e \u003cspan class=\"n\"\u003eThis\u003c/span\u003e \u003cspan class=\"n\"\u003ecode\u003c/span\u003e \u003cspan class=\"n\"\u003eis\u003c/span\u003e \u003cspan class=\"n\"\u003ecalled\u003c/span\u003e \u003cspan class=\"n\"\u003eon\u003c/span\u003e \u003cspan class=\"n\"\u003eevery\u003c/span\u003e \u003cspan class=\"n\"\u003epacket\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003eGoogle\u003c/span\u003e \u003cspan class=\"n\"\u003eMeet\u003c/span\u003e \u003cspan class=\"n\"\u003evideo\u003c/span\u003e \u003cspan class=\"n\"\u003econferencing\u003c/span\u003e \u003cspan class=\"n\"\u003esystem\u003c/span\u003e \u003cspan class=\"n\"\u003eand\u003c/span\u003e \u003cspan class=\"n\"\u003eneeded\u003c/span\u003e \u003cspan class=\"n\"\u003eperformance\u003c/span\u003e \u003cspan class=\"n\"\u003ework\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"n\"\u003ekeep\u003c/span\u003e \u003cspan class=\"n\"\u003eup\u003c/span\u003e \u003cspan class=\"n\"\u003ewith\u003c/span\u003e \u003cspan class=\"n\"\u003ecapacity\u003c/span\u003e \u003cspan class=\"n\"\u003edemands\u003c/span\u003e \u003cspan class=\"n\"\u003eduring\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003efirst\u003c/span\u003e \u003cspan class=\"n\"\u003epart\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003eCOVID\u003c/span\u003e \u003cspan 
class=\"n\"\u003eoutbreak\u003c/span\u003e \u003cspan class=\"n\"\u003eas\u003c/span\u003e \u003cspan class=\"n\"\u003eusers\u003c/span\u003e \u003cspan class=\"n\"\u003erapidly\u003c/span\u003e \u003cspan class=\"n\"\u003emigrated\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"n\"\u003edoing\u003c/span\u003e \u003cspan class=\"n\"\u003emore\u003c/span\u003e \u003cspan class=\"n\"\u003eonline\u003c/span\u003e \u003cspan class=\"n\"\u003emeetings\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eScopedPerformanceMeasurement\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003epublic:\u003c/span\u003e\n \u003cspan class=\"k\"\u003eexplicit\u003c/span\u003e \u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ePacketExecutor\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003epacket_executor_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003etracer_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor_trace_threshold_\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003ekClosureTraceName\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// ThreadCPUUsage is an expensive call. At the time of writing,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// it takes over 400ns, or roughly 30 times slower than absl::Now,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// so we sample only 10% of closures to keep the cost down.\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eclosures_executed_\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e \u003cspan class=\"mi\"\u003e10\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ethread_cpu_usage_start_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eThreadCPUUsage\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Sample start time after potentially making the above expensive call,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// so as not to pollute wall time measurements.\u003c/span\u003e\n \u003cspan class=\"n\"\u003erun_start_time_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNow\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan 
class=\"o\"\u003e~\u003c/span\u003e\u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n\u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePacketExecutor\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003epacket_executor_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003etracer_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor_trace_threshold_\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003ekClosureTraceName\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// ThreadCPUUsage is an expensive call. 
At the time of writing,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// it takes over 400ns, or roughly 30 times slower than absl::Now,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// so we sample only 1 in 32 closures to keep the cost down.\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eclosures_executed_\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e \u003cspan class=\"mi\"\u003e32\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ethread_cpu_usage_start_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ebase\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eThreadCPUUsage\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// Sample start time after potentially making the above expensive call,\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// so as not to pollute wall time measurements.\u003c/span\u003e\n \u003cspan class=\"n\"\u003erun_start_time_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNow\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003epacket_executor\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"o\"\u003e~\u003c/span\u003e\u003cspan 
class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erun_end_time\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNow\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erun_duration\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003erun_end_time\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003erun_start_time_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ethread_cpu_usage_start_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ehas_value\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003eclosure_execution_time\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eRecord\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eToInt64Microseconds\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003erun_duration\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n\u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan 
class=\"o\"\u003e::~\u003c/span\u003e\u003cspan class=\"n\"\u003eScopedPerformanceMeasurement\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erun_end_time\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eNow\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"k\"\u003eauto\u003c/span\u003e \u003cspan class=\"n\"\u003erun_duration\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003erun_end_time\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003erun_start_time_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ethread_cpu_usage_start_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ehas_value\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eclosure_execution_time\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eRecord\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eToInt64Microseconds\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003erun_duration\u003c/span\u003e\u003cspan class=\"p\"\u003e));\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 
id=\"avoid-logging-on-hot-code-paths\"\u003eAvoid logging on hot code paths\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLogging statements can be costly, even if the logging-level for the statement doesn’t actually log anything. E.g., ABSL_VLOG’s implementation requires at least a load and a comparison, which may be a problem in hot code paths. In addition, the presence of the logging code may inhibit compiler optimizations. Consider dropping logging entirely from hot code paths.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eimage_similarity\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eoutput_subimage_size_y\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej1\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003erad\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003eoutput_to_integral_subimage_y\u003c/span\u003e\u003cspan 
class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej2\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ej1\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e2\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003erad\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create a pointer for this row's output, taking into account the offset\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// to the full image.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eimage_diff_ptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eimage_diff\u003c/span\u003e\u003cspan class=\"p\"\u003e)(\u003c/span\u003e\u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003emin_j\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003emin_i\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan 
class=\"n\"\u003eoutput_subimage_size_x\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eVLOG_IS_ON\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e))\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003evlog_3\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eDEBUG_MODE\u003c/span\u003e \u003cspan class=\"o\"\u003e?\u003c/span\u003e \u003cspan class=\"n\"\u003eVLOG_IS_ON\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"nb\"\u003efalse\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan 
class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eoutput_subimage_size_y\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej1\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan class=\"n\"\u003erad\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003eoutput_to_integral_subimage_y\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ej2\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ej1\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e2\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003erad\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Create a pointer for this row's output, taking into account the offset\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// to the full image.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eimage_diff_ptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eimage_diff\u003c/span\u003e\u003cspan class=\"p\"\u003e)(\u003c/span\u003e\u003cspan 
class=\"n\"\u003ej\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003emin_j\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003emin_i\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003eoutput_subimage_size_x\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003evlog_3\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003eRun\u003c/span\u003e \u003cspan class=\"nf\"\u003eon\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e40\u003c/span\u003e \u003cspan class=\"n\"\u003eX\u003c/span\u003e \u003cspan class=\"mi\"\u003e2801\u003c/span\u003e \u003cspan class=\"n\"\u003eMHz\u003c/span\u003e \u003cspan class=\"n\"\u003eCPUs\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e \u003cspan 
class=\"mi\"\u003e2016\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mo\"\u003e05\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e16\u003c/span\u003e\u003cspan class=\"n\"\u003eT15\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mi\"\u003e55\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mf\"\u003e32.250633072\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mo\"\u003e07\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mo\"\u003e00\u003c/span\u003e\n\u003cspan class=\"n\"\u003eCPU\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003eIntel\u003c/span\u003e \u003cspan class=\"n\"\u003eIvybridge\u003c/span\u003e \u003cspan class=\"n\"\u003ewith\u003c/span\u003e \u003cspan class=\"n\"\u003eHyperThreading\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e20\u003c/span\u003e \u003cspan class=\"n\"\u003ecores\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"n\"\u003edL1\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mi\"\u003e32\u003c/span\u003e\u003cspan class=\"n\"\u003eKB\u003c/span\u003e \u003cspan class=\"n\"\u003edL2\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mi\"\u003e256\u003c/span\u003e\u003cspan class=\"n\"\u003eKB\u003c/span\u003e \u003cspan class=\"n\"\u003edL3\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\u003cspan class=\"mi\"\u003e25\u003c/span\u003e\u003cspan class=\"n\"\u003eMB\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBenchmark\u003c/span\u003e \u003cspan class=\"n\"\u003eBase\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ens\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan 
class=\"n\"\u003eNew\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ens\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"n\"\u003eImprovement\u003c/span\u003e\n\u003cspan class=\"o\"\u003e------------------------------------------------------------------\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e16\u003c/span\u003e \u003cspan class=\"mi\"\u003e29104\u003c/span\u003e \u003cspan class=\"mi\"\u003e26372\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e9.4\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e64\u003c/span\u003e \u003cspan class=\"mi\"\u003e473235\u003c/span\u003e \u003cspan class=\"mi\"\u003e425281\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e10.1\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e512\u003c/span\u003e \u003cspan class=\"mi\"\u003e30246238\u003c/span\u003e \u003cspan class=\"mi\"\u003e27622009\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e8.7\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"n\"\u003ek\u003c/span\u003e \u003cspan class=\"mi\"\u003e125651445\u003c/span\u003e \u003cspan class=\"mi\"\u003e113361991\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e9.8\u003c/span\u003e\u003cspan 
class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCLimitedBoundsPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e16\u003c/span\u003e \u003cspan class=\"mi\"\u003e8314\u003c/span\u003e \u003cspan class=\"mi\"\u003e7498\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e9.8\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCLimitedBoundsPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e64\u003c/span\u003e \u003cspan class=\"mi\"\u003e143508\u003c/span\u003e \u003cspan class=\"mi\"\u003e132202\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e7.9\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCLimitedBoundsPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e512\u003c/span\u003e \u003cspan class=\"mi\"\u003e9335684\u003c/span\u003e \u003cspan class=\"mi\"\u003e8477567\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e9.2\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003cspan class=\"n\"\u003eBM_NCCLimitedBoundsPerformance\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"n\"\u003ek\u003c/span\u003e \u003cspan class=\"mi\"\u003e37223897\u003c/span\u003e \u003cspan class=\"mi\"\u003e34201739\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e\u003cspan class=\"mf\"\u003e8.1\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch2 id=\"code-size-considerations\"\u003eCode size considerations\u003c/h2\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePerformance encompasses more than just 
runtime speed. Sometimes it is worth considering the effects of software choices on the size of generated code. Large code size means longer compile and link times, bloated binaries, more memory usage, more icache pressure, and other sometimes negative effects on microarchitectural structures like branch predictors, etc. Thinking about these issues is especially important when writing low-level library code that will be used in many places, or when writing templated code that you expect will be instantiated for many different types.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eTurn\u003c/span\u003e \u003cspan class=\"n\"\u003emany\u003c/span\u003e \u003cspan class=\"n\"\u003emap\u003c/span\u003e \u003cspan class=\"n\"\u003einsertion\u003c/span\u003e \u003cspan class=\"n\"\u003ecalls\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003erow\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"n\"\u003einitialize\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003ehash\u003c/span\u003e \u003cspan class=\"n\"\u003etable\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003eemoji\u003c/span\u003e \u003cspan class=\"n\"\u003echaracters\u003c/span\u003e \u003cspan class=\"n\"\u003einto\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003esingle\u003c/span\u003e \u003cspan class=\"n\"\u003ebulk\u003c/span\u003e \u003cspan class=\"n\"\u003einsert\u003c/span\u003e \u003cspan class=\"nf\"\u003eoperation\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e188\u003c/span\u003e\u003cspan class=\"n\"\u003eKB\u003c/span\u003e \u003cspan
class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003etext\u003c/span\u003e \u003cspan class=\"n\"\u003edown\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"mi\"\u003e360\u003c/span\u003e \u003cspan class=\"n\"\u003ebytes\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003elibrary\u003c/span\u003e \u003cspan class=\"n\"\u003elinked\u003c/span\u003e \u003cspan class=\"n\"\u003einto\u003c/span\u003e \u003cspan class=\"n\"\u003emany\u003c/span\u003e \u003cspan class=\"n\"\u003ebinaries\u003c/span\u003e\u003cspan class=\"p\"\u003e).\u003c/span\u003e \u003cspan class=\"err\"\u003e😊\u003c/span\u003e\n\u003cspan class=\"n\"\u003etextfallback_init\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eh\u003c/span\u003e\n\n\u003cspan class=\"kr\"\u003einline\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eAddEmojiFallbacks\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eTextFallbackMap\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFE000\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE000\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan 
class=\"mh\"\u003e0xFE001\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE001\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFE002\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE002\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFE003\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE003\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFE004\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE004\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan 
class=\"mh\"\u003e0xFE005\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE005\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFEE7D\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFEE7D\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFEEA0\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFEEA0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"mh\"\u003e0xFE331\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003ekFE331\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003cspan class=\"kr\"\u003einline\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan 
class=\"nf\"\u003eAddEmojiFallbacks\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eTextFallbackMap\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#define PAIR(x) {0x##x, \u0026amp;k##x}\n\u003c/span\u003e \u003cspan class=\"c1\"\u003e// clang-format off\u003c/span\u003e\n \u003cspan class=\"n\"\u003emap\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003einsert\u003c/span\u003e\u003cspan class=\"p\"\u003e({\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE000\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE001\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE002\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE003\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE004\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE005\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFEE7D\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFEEA0\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e\n \u003cspan class=\"n\"\u003ePAIR\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eFE331\u003c/span\u003e\u003cspan class=\"p\"\u003e)});\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// clang-format on\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#undef PAIR\n\u003c/span\u003e\u003cspan class=\"p\"\u003e};\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"parallelization-and-synchronization\"\u003eParallelization and synchronization\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eModern machines have many cores, and they are often underutilized. Expensive work may therefore be completed faster by parallelizing it. The most common approach is to process different items in parallel and combine the results when done. 
Typically, the items are first partitioned into batches to avoid paying the cost of running something in parallel per item.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eFour\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003eway\u003c/span\u003e \u003cspan class=\"n\"\u003eparallelization\u003c/span\u003e \u003cspan class=\"n\"\u003eimproves\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003erate\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003eencoding\u003c/span\u003e \u003cspan class=\"n\"\u003etokens\u003c/span\u003e \u003cspan class=\"n\"\u003eby\u003c/span\u003e \u003cspan class=\"o\"\u003e~\u003c/span\u003e\u003cspan class=\"mf\"\u003e3.6\u003c/span\u003e\u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\u003cspan class=\"n\"\u003eblocked\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003etoken\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003ecoder\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"nf\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003eencoder_threads_lock\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eencoder_threads\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"nb\"\u003eNULL\u003c/span\u003e\u003cspan 
class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eencoder_threads\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003eThreadPool\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eNumCPUs\u003c/span\u003e\u003cspan class=\"p\"\u003e());\u003c/span\u003e\n \u003cspan class=\"n\"\u003eencoder_threads\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eSetStackSize\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e262144\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eencoder_threads\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eStartWorkers\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003eencoder_threads\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eAdd\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eNewCallback\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003ethis\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003eBlockedTokenEncoder\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eEncodeRegionInThread\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003eregion_tokens\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eregion\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003estats\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"n\"\u003econtroller_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eGetClosureWithCost\u003c/span\u003e\n \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eNewCallback\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003eDummyCallback\u003c/span\u003e\u003cspan class=\"p\"\u003e),\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e)));\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe effect on system performance should be measured carefully – if spare CPU is not available, or if memory bandwidth is saturated, parallelization may not help, or may even hurt.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is the caveat. It’s hard to gauge this for every type of machine.\u003c/p\u003e\n\n\u003ch3 id=\"amortize-lock-acquisition\"\u003eAmortize lock acquisition\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAvoid fine-grained locking to reduce the cost of Mutex operations in hot paths. Caveat: this should only be done if the change does not increase lock contention.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting.
Yes: if another thread is also contending for the lock, fine-grained locking could theoretically be faster (say, if some section doesn’t actually use the shared variable).\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e// Acquire lock once to free entire tree of query nodes, rather than reacquiring lock for every node in tree.\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e// Pool of query nodes\u003c/span\u003e\n\u003cspan class=\"n\"\u003eThreadSafeFreeList\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003epool_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e256\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRelease\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"nb\"\u003eNULL\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan
class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e \u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003eRelease\u003c/span\u003e\u003cspan class=\"p\"\u003e((\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]);\u003c/span\u003e\n \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eclear\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003epool_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eDelete\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// Pool of query nodes\u003c/span\u003e\n\u003cspan 
class=\"n\"\u003eMutex\u003c/span\u003e \u003cspan class=\"n\"\u003epool_lock_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"n\"\u003eFreeList\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003epool_\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e256\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eRelease\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"nb\"\u003eNULL\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003epool_lock_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eReleaseLocked\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eReleaseLocked\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eMustangQuery\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#ifndef NDEBUG\n\u003c/span\u003e \u003cspan class=\"n\"\u003epool_lock_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eAssertHeld\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#endif\n\u003c/span\u003e \u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"nb\"\u003eNULL\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan 
class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e \u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003eReleaseLocked\u003c/span\u003e\u003cspan class=\"p\"\u003e((\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"p\"\u003e)[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]);\u003c/span\u003e\n \u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003echildren_\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u0026gt;\u003c/span\u003e\u003cspan class=\"n\"\u003eclear\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003epool_\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eDelete\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003enode\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"keep-critical-sections-short\"\u003eKeep critical sections short\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAvoid expensive work inside critical sections. 
In particular, watch out for innocuous looking code that might be doing RPCs or accessing files.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eBasically, minimize critical sections, and prioritize the ones where shrinking them has the highest ROI.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eAvoid\u003c/span\u003e \u003cspan class=\"n\"\u003eRPC\u003c/span\u003e \u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"n\"\u003eholding\u003c/span\u003e \u003cspan class=\"n\"\u003eMutex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\u003cspan class=\"n\"\u003etrainer\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Notify the parameter server that we are starting.\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003elock_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003emodel_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003emodel\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMaybeRecordProgress\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elast_global_step_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"kt\"\u003ebool\u003c/span\u003e \u003cspan class=\"n\"\u003eshould_start_record_progress\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan
class=\"nb\"\u003efalse\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"n\"\u003eint64\u003c/span\u003e \u003cspan class=\"n\"\u003estep_for_progress\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Notify the parameter server that we are starting.\u003c/span\u003e\n \u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003elock_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003emodel_\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003emodel\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eshould_start_record_progress\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eShouldStartRecordProgress\u003c/span\u003e\u003cspan class=\"p\"\u003e();\u003c/span\u003e\n \u003cspan class=\"n\"\u003estep_for_progress\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003elast_global_step_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"k\"\u003eif\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eshould_start_record_progress\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eStartRecordProgress\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003estep_for_progress\u003c/span\u003e\u003cspan 
class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"reduce-contention-by-sharding\"\u003eReduce contention by sharding\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSometimes a data structure protected by a Mutex that is exhibiting high contention can be safely split into multiple shards, each shard with its own Mutex. (Note: this requires that there are no cross-shard invariants between the different shards.)\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis means that independent parts of the data can be accessed in parallel, as long as no operation needs a consistent view across all shards at once. I hadn’t realized you could simply keep several independently locked instances of the structure, as \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eShardedLRUCache\u003c/code\u003e does below with its array of \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eLRUCache\u003c/code\u003e shards.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eShardedLRUCache\u003c/span\u003e \u003cspan class=\"o\"\u003e:\u003c/span\u003e \u003cspan class=\"k\"\u003epublic\u003c/span\u003e \u003cspan class=\"n\"\u003eCache\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"nl\"\u003eprivate:\u003c/span\u003e\n \u003cspan class=\"n\"\u003eLRUCache\u003c/span\u003e \u003cspan class=\"n\"\u003eshard_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ekNumShards\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n \u003cspan class=\"n\"\u003eport\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eMutex\u003c/span\u003e \u003cspan class=\"n\"\u003eid_mutex_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"kt\"\u003euint64_t\u003c/span\u003e \u003cspan class=\"n\"\u003elast_id_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan
class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"kr\"\u003einline\u003c/span\u003e \u003cspan class=\"kt\"\u003euint32_t\u003c/span\u003e \u003cspan class=\"n\"\u003eHashSlice\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eSlice\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eHash\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003edata\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e \u003cspan class=\"n\"\u003es\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003esize\u003c/span\u003e\u003cspan class=\"p\"\u003e(),\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\n \u003cspan class=\"k\"\u003estatic\u003c/span\u003e \u003cspan class=\"kt\"\u003euint32_t\u003c/span\u003e \u003cspan class=\"nf\"\u003eShard\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003euint32_t\u003c/span\u003e \u003cspan class=\"n\"\u003ehash\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ehash\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026gt;\u0026gt;\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"mi\"\u003e32\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e \u003cspan 
class=\"n\"\u003ekNumShardBits\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003evirtual\u003c/span\u003e \u003cspan class=\"n\"\u003eHandle\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"nf\"\u003eLookup\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eSlice\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003ekey\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003euint32_t\u003c/span\u003e \u003cspan class=\"n\"\u003ehash\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eHashSlice\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ekey\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003eshard_\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003eShard\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ehash\u003c/span\u003e\u003cspan class=\"p\"\u003e)].\u003c/span\u003e\u003cspan class=\"n\"\u003eLookup\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ekey\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ehash\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBe careful with the information used for shard selection. 
If, for example, you use some bits of a hash value for shard selection and then those same bits end up being used again later, the latter use may perform poorly since it sees a skewed distribution of hash values.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eFor sharding, an even distribution across shards matters: if one shard receives a disproportionate share of the traffic, its Mutex simply becomes the new bottleneck.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eThis\u003c/span\u003e \u003cspan class=\"n\"\u003eCL\u003c/span\u003e \u003cspan class=\"n\"\u003epartitions\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003eActiveCallMap\u003c/span\u003e \u003cspan class=\"n\"\u003einto\u003c/span\u003e \u003cspan class=\"mi\"\u003e64\u003c/span\u003e \u003cspan class=\"n\"\u003eshards\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e \u003cspan class=\"n\"\u003eEach\u003c/span\u003e \u003cspan class=\"n\"\u003eshard\u003c/span\u003e \u003cspan class=\"n\"\u003eis\u003c/span\u003e \u003cspan class=\"k\"\u003eprotected\u003c/span\u003e \u003cspan class=\"n\"\u003eby\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003eseparate\u003c/span\u003e \u003cspan class=\"n\"\u003emutex\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e \u003cspan class=\"n\"\u003eA\u003c/span\u003e \u003cspan class=\"n\"\u003egiven\u003c/span\u003e \u003cspan class=\"n\"\u003etransaction\u003c/span\u003e \u003cspan class=\"n\"\u003ewill\u003c/span\u003e \u003cspan class=\"n\"\u003ebe\u003c/span\u003e \u003cspan class=\"n\"\u003emapped\u003c/span\u003e \u003cspan class=\"n\"\u003eto\u003c/span\u003e \u003cspan class=\"n\"\u003eexactly\u003c/span\u003e \u003cspan class=\"n\"\u003eone\u003c/span\u003e \u003cspan class=\"n\"\u003eshard\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e \u003cspan
class=\"n\"\u003eA\u003c/span\u003e \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003einterface\u003c/span\u003e \u003cspan class=\"nf\"\u003eLockedShard\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etid\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"n\"\u003eis\u003c/span\u003e \u003cspan class=\"n\"\u003eadded\u003c/span\u003e \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003eaccessing\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003eActiveCallMap\u003c/span\u003e \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003etransaction\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"kr\"\u003ethread\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003esafe\u003c/span\u003e \u003cspan class=\"n\"\u003emanner\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e \u003cspan class=\"n\"\u003eExample\u003c/span\u003e \u003cspan class=\"n\"\u003eusage\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003etransaction_manager\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eMutexLock\u003c/span\u003e \u003cspan class=\"n\"\u003el\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan class=\"n\"\u003eactive_calls_in_mu_\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003edelayed_locks_timer_ring_\u003c/span\u003e\u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eAdd\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edelayed_locks_flush_time_ms\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etid\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eActiveCalls\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eLockedShard\u003c/span\u003e \u003cspan class=\"n\"\u003eshard\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eactive_calls_in_\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etid\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eshard\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003edelayed_locks_timer_ring\u003c/span\u003e\u003cspan class=\"p\"\u003e().\u003c/span\u003e\u003cspan class=\"n\"\u003eAdd\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003edelayed_locks_flush_time_ms\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003etid\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003cspan class=\"n\"\u003eThe\u003c/span\u003e \u003cspan class=\"n\"\u003eresults\u003c/span\u003e \u003cspan class=\"n\"\u003eshow\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"mi\"\u003e69\u003c/span\u003e\u003cspan class=\"o\"\u003e%\u003c/span\u003e \u003cspan class=\"n\"\u003ereduction\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003eoverall\u003c/span\u003e \u003cspan class=\"n\"\u003ewall\u003c/span\u003e\u003cspan 
class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003eclock\u003c/span\u003e \u003cspan class=\"n\"\u003etime\u003c/span\u003e \u003cspan class=\"n\"\u003ewhen\u003c/span\u003e \u003cspan class=\"n\"\u003erunning\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003ebenchmark\u003c/span\u003e \u003cspan class=\"n\"\u003ewith\u003c/span\u003e \u003cspan class=\"mi\"\u003e8192\u003c/span\u003e \u003cspan class=\"n\"\u003efibers\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"reduce-false-sharing\"\u003eReduce false sharing\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf different threads access different mutable data, consider placing the different data items on different cache lines, e.g., in C++ using the alignas directive. However, these directives are easy to misuse and may increase object sizes significantly, so make sure performance measurements justify their use.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis trades memory for performance. How do you even identify such a problem? On Linux, one option is \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eperf c2c\u003c/code\u003e, which reports cache lines that are contended between cores.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003ehistogram\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eh\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eHistogramOptions\u003c/span\u003e \u003cspan class=\"n\"\u003eoptions_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"n\"\u003einternal\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eHistogramBoundaries\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eboundaries_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan
class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003ebuckets_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003emin_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Minimum.\u003c/span\u003e\n\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003emax_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Maximum.\u003c/span\u003e\n\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003ecount_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Total count of occurrences.\u003c/span\u003e\n\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003esum_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Sum of values.\u003c/span\u003e\n\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003esum_of_squares_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Sum of squares of values.\u003c/span\u003e\n\u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"n\"\u003eRegisterVariableExporter\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eexporter_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eHistogramOptions\u003c/span\u003e \u003cspan class=\"n\"\u003eoptions_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan 
class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003einternal\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eHistogramBoundaries\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eboundaries_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003eRegisterVariableExporter\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eexporter_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Place the following fields in a dedicated cacheline as they are frequently\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// mutated, so we can avoid potential false sharing.\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#ifndef SWIG\n\u003c/span\u003e \u003cspan class=\"k\"\u003ealignas\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eABSL_CACHELINE_SIZE\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#endif\n\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"kt\"\u003edouble\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003ebuckets_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003emin_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Minimum.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan 
class=\"n\"\u003emax_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Maximum.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003ecount_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Total count of occurrences.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003esum_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Sum of values.\u003c/span\u003e\n \u003cspan class=\"kt\"\u003edouble\u003c/span\u003e \u003cspan class=\"n\"\u003esum_of_squares_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Sum of squares of values.\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch3 id=\"reduce-frequency-of-context-switches\"\u003eReduce frequency of context switches\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eProcess small work items inline instead of on device thread pool.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is hard to spot without a tracer.\u003c/p\u003e\n\n\u003ch3 id=\"consider-lock-free-approaches\"\u003eConsider lock-free approaches\u003c/h3\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSometimes lock-free data structures can make a difference over more conventional mutex-protected data structures. However, direct atomic variable manipulation can be dangerous. Prefer higher-level abstractions.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eLock-free code is extremely hard to debug, and its bugs are hard to catch. I don’t have expertise in this area.\u003c/p\u003e\n\n\u003ch3 id=\"protocol-buffer-advice\"\u003eProtocol Buffer advice\u003c/h3\u003e\n\n\u003cp\u003eThis section is large, and for good reason.
Messages are a foundational building block of any distributed system, so shaving even a small percentage off their cost pays off broadly. The practices in this section apply to any message protocol, not just protocol buffers.\u003c/p\u003e\n\n\u003cp\u003eMy main takeaway: look at the generated serialization code, understand its serialization/deserialization overhead, and apply the practices that reduce it, either by changing the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e.proto\u003c/code\u003e file or by changing the C++ code (for example, by allocating messages on arenas).\u003c/p\u003e\n\n\u003ch3 id=\"c-specific-advice\"\u003eC++-Specific advice\u003c/h3\u003e\n\n\u003cp\u003ePrefer \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eabsl::flat_hash_map\u003c/code\u003e (and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eabsl::flat_hash_set\u003c/code\u003e) over the standard hash containers. More generally, this holds for almost all standard library containers in C++, with a few exceptions such as \u003ccode class=\"language-plaintext highlighter-rouge\"\u003estd::vector\u003c/code\u003e.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eabsl::InlinedVector stores a small number of elements inline (configurable via the second template argument). This enables small vectors up to this number of elements to generally have better cache efficiency and also to avoid allocating a backing store array at all when the number of elements is small.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIn other words, the first few elements are stored inside the object itself (on the stack, when the vector is a local variable) rather than in a separate heap allocation.
It is a nice design, similar to LLVM’s \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ellvm::SmallVector\u003c/code\u003e.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003egtl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector32\u003c/span\u003e\n\u003cspan class=\"n\"\u003eSaves\u003c/span\u003e \u003cspan class=\"n\"\u003espace\u003c/span\u003e \u003cspan class=\"n\"\u003eby\u003c/span\u003e \u003cspan class=\"k\"\u003eusing\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003ecustomized\u003c/span\u003e \u003cspan class=\"n\"\u003evector\u003c/span\u003e \u003cspan class=\"n\"\u003etype\u003c/span\u003e \u003cspan class=\"n\"\u003ethat\u003c/span\u003e \u003cspan class=\"n\"\u003eonly\u003c/span\u003e \u003cspan class=\"n\"\u003esupports\u003c/span\u003e \u003cspan class=\"n\"\u003esizes\u003c/span\u003e \u003cspan class=\"n\"\u003ethat\u003c/span\u003e \u003cspan class=\"n\"\u003efit\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"mi\"\u003e32\u003c/span\u003e \u003cspan class=\"n\"\u003ebits\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003eSimple\u003c/span\u003e \u003cspan class=\"n\"\u003etype\u003c/span\u003e \u003cspan class=\"n\"\u003echange\u003c/span\u003e \u003cspan class=\"n\"\u003esaves\u003c/span\u003e \u003cspan class=\"o\"\u003e~\u003c/span\u003e\u003cspan class=\"mi\"\u003e8\u003c/span\u003e\u003cspan class=\"n\"\u003eTiB\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003ememory\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003eSpanner\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\u003cspan class=\"n\"\u003etable_ply\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan
class=\"n\"\u003eh\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003eclass\u003c/span\u003e \u003cspan class=\"nc\"\u003eTablePly\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Returns the set of data columns stored in this file for this table.\u003c/span\u003e\n \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eFamilyId\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u0026amp;\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"k\"\u003eprivate\u003c/span\u003e\u003cspan class=\"o\"\u003e:\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"n\"\u003estd\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eFamilyId\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"c1\"\u003e// Data columns in the table.\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#include\u003c/span\u003e \u003cspan 
class=\"cpf\"\u003e\"util/gtl/vector32.h\"\u003c/span\u003e\u003cspan class=\"cp\"\u003e\n\u003c/span\u003e \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Returns the set of data columns stored in this file for this table.\u003c/span\u003e\n \u003cspan class=\"n\"\u003eabsl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eSpan\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"n\"\u003eFamilyId\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e \u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n\n \u003cspan class=\"p\"\u003e...\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// Data columns in the table.\u003c/span\u003e\n \u003cspan class=\"n\"\u003egtl\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003evector32\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e\u003cspan class=\"n\"\u003eFamilyId\u003c/span\u003e\u003cspan class=\"o\"\u003e\u0026gt;\u003c/span\u003e \u003cspan class=\"n\"\u003emodified_data_columns_\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eThis is very cool. 
Presumably the saving comes from storing the size and capacity as 32-bit integers instead of 64-bit ones, which shrinks every vector object.\u003c/p\u003e\n\n\u003ch1 id=\"bulk-operations\"\u003eBulk operations\u003c/h1\u003e\n\n\u003cp\u003eAs usual, operating in bulk is the answer: it amortizes per-item overhead, and memory access is often the bottleneck.\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003eIntroduced\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003eGroupVarInt\u003c/span\u003e \u003cspan class=\"n\"\u003eformat\u003c/span\u003e \u003cspan class=\"n\"\u003ethat\u003c/span\u003e \u003cspan class=\"n\"\u003eencodes\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"n\"\u003edecodes\u003c/span\u003e \u003cspan class=\"n\"\u003egroups\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e \u003cspan class=\"n\"\u003evariable\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003elength\u003c/span\u003e \u003cspan class=\"n\"\u003eintegers\u003c/span\u003e \u003cspan class=\"n\"\u003eat\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003etime\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"mi\"\u003e5\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e17\u003c/span\u003e \u003cspan class=\"n\"\u003ebytes\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erather\u003c/span\u003e \u003cspan class=\"n\"\u003ethan\u003c/span\u003e \u003cspan class=\"n\"\u003eone\u003c/span\u003e \u003cspan class=\"n\"\u003einteger\u003c/span\u003e \u003cspan class=\"n\"\u003eat\u003c/span\u003e \u003cspan class=\"n\"\u003ea\u003c/span\u003e \u003cspan class=\"n\"\u003etime\u003c/span\u003e\u003cspan
class=\"p\"\u003e.\u003c/span\u003e \u003cspan class=\"n\"\u003eDecoding\u003c/span\u003e \u003cspan class=\"n\"\u003eone\u003c/span\u003e \u003cspan class=\"n\"\u003egroup\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e \u003cspan class=\"n\"\u003eintegers\u003c/span\u003e \u003cspan class=\"n\"\u003ein\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"k\"\u003enew\u003c/span\u003e \u003cspan class=\"n\"\u003eformat\u003c/span\u003e \u003cspan class=\"n\"\u003etakes\u003c/span\u003e \u003cspan class=\"o\"\u003e~\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"o\"\u003e/\u003c/span\u003e\u003cspan class=\"mi\"\u003e3\u003c/span\u003e\u003cspan class=\"n\"\u003erd\u003c/span\u003e \u003cspan class=\"n\"\u003ethe\u003c/span\u003e \u003cspan class=\"n\"\u003etime\u003c/span\u003e \u003cspan class=\"n\"\u003eof\u003c/span\u003e \u003cspan class=\"n\"\u003edecoding\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e \u003cspan class=\"n\"\u003eindividually\u003c/span\u003e \u003cspan class=\"n\"\u003evarint\u003c/span\u003e\u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"n\"\u003eencoded\u003c/span\u003e \u003cspan class=\"n\"\u003eintegers\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\n\n\u003cspan class=\"n\"\u003egroupvarint\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecc\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"nf\"\u003eDecodeGroupVar\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"k\"\u003econst\u003c/span\u003e \u003cspan class=\"kt\"\u003echar\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan 
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003euint32\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003edest\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003eassert\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003egroupvar_initialized\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003eassert\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eN\u003c/span\u003e \u003cspan class=\"o\"\u003e%\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e \u003cspan class=\"o\"\u003e==\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"k\"\u003ewhile\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003euint8\u003c/span\u003e \u003cspan class=\"n\"\u003etag\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\n \u003cspan class=\"n\"\u003euint8\u003c/span\u003e\u003cspan class=\"o\"\u003e*\u003c/span\u003e \u003cspan class=\"n\"\u003elenptr\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026amp;\u003c/span\u003e\u003cspan 
class=\"n\"\u003egroupvar_table\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003etag\u003c/span\u003e\u003cspan class=\"p\"\u003e].\u003c/span\u003e\u003cspan class=\"n\"\u003elength\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n\n\u003cspan class=\"cp\"\u003e#define GET_NEXT \\\n do { \\\n uint8 len = *lenptr; \\\n *dest = UNALIGNED_LOAD32(p) \u0026amp; groupvar_mask[len]; \\\n dest++; \\\n p += len; \\\n lenptr++; \\\n } while (0)\n\u003c/span\u003e \u003cspan class=\"n\"\u003eGET_NEXT\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eGET_NEXT\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eGET_NEXT\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"n\"\u003eGET_NEXT\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"cp\"\u003e#undef GET_NEXT\n\u003c/span\u003e\n \u003cspan class=\"n\"\u003eN\u003c/span\u003e \u003cspan class=\"o\"\u003e-=\u003c/span\u003e \u003cspan class=\"mi\"\u003e4\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n \u003cspan class=\"p\"\u003e}\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"n\"\u003ep\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch1 id=\"cls-that-demonstrate-multiple-techniques\"\u003eCLs that demonstrate multiple techniques\u003c/h1\u003e\n\n\u003cp\u003eThis section is on seeing how a combination of techniques can be used to optimize a small part of a program and what to expect to the overall program.\u003c/p\u003e\n\n\u003cp\u003eFor example, one speeds up GPU allocator by 40% using less bytes, caching aligning, caching and 
commenting out logging, resulting in a 2.9% end-to-end speedup.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSpeed up low level logging in Google Meet application code.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis one changed the logging state from a vector to a static array of size 4, resulting in a 50% speedup for logging, which is probably a commonly called path.\u003c/p\u003e\n\n\u003cp\u003eI think all of these require very deep insight into what the program is doing and where it spends its time.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWe found a number of performance issues when planning a switch from on-disk to in-memory index serving in 2001. This change fixed many of these problems and took us from 150 to over 500 in-memory queries per second (for a 2 GB in-memory index on dual processor Pentium III machine).\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is from back in the day, when search was just becoming available! It most likely still applies to personally written code, but much less so these days, since most people know the general optimizations.\u003c/p\u003e\n\n\u003ch1 id=\"further-reading\"\u003eFurther reading\u003c/h1\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eUnderstanding Software Dynamics by Richard L. Sites.
Covers expert methods and advanced tools for diagnosing and fixing performance problems.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGood book.\u003c/p\u003e","summary":"\u0026lt;!–","date_published":"2026-01-26T12:00:00+00:00","date_modified":"2026-01-26T12:00:00+00:00","author":{"name":""},"tags":["Rambling"]},{"id":"https://maknee.github.io/blog/2025/Maybe-Consider-Putting-Cutlass-In-Your-CUDA-Kernels","url":"https://maknee.github.io/blog/2025/Maybe-Consider-Putting-Cutlass-In-Your-CUDA-Kernels/","title":"Maybe consider putting “cutlass” in your CUDA/Triton kernels","content_html":"\u003ch1 id=\"motivation\"\u003eMotivation\u003c/h1\u003e\n\n\u003cp\u003eSo I was browsing Hacker News and came across this interesting post: \u003ca href=\"https://news.ycombinator.com/item?id=45458948\"\u003eFp8 runs ~100 tflops faster when the kernel name has “cutlass” in it\u003c/a\u003e.\u003c/p\u003e\n\n\u003cp\u003eThis came from a Triton tutorial, where someone noticed that adding “cutlass” to their kernel name gave them an additional 100-150 TFLOPs.
That’s a huge improvement just from… a name?\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/original1.png\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eMentions a 100 TFLOPs improvement\n \n (Image source: \u003ca href=\"https://github.com/triton-lang/triton/pull/7298\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eGitHub pull request\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/original2.png\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eMentions a 150 TFLOPs improvement from renaming Triton kernels to include “cutlass”\n \n (Image source: \u003ca href=\"https://github.com/triton-lang/triton/pull/7298\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eGitHub pull request\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eWell, I got curious and wanted to know why this happens.\u003c/p\u003e\n\n\u003ch1 id=\"so-what-exactly-is-this\"\u003eSo… what exactly is this?\u003c/h1\u003e\n\n\u003cp\u003eInstead of writing your kernel like this:\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003e__global__\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003eadd\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003esum\u003c/span\u003e\u003cspan
class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003en\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ey\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003en\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan class=\"n\"\u003esum\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ey\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n\u003cspan 
class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eYou add “cutlass” to the name:\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003e__global__\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003eadd_cutlass\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003esum\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003en\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003ey\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"o\"\u003e\u0026lt;\u003c/span\u003e \u003cspan class=\"n\"\u003en\u003c/span\u003e\u003cspan class=\"p\"\u003e;\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"o\"\u003e++\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003esum\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ex\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ey\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e];\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eand \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eptxas\u003c/code\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eIf you need some background on the CUDA compilation toolchain, refer to the \u003ca href=\"#nvidia-toolchain-background\"\u003esection on NVIDIA toolchain background\u003c/a\u003e\u003c/span\u003e will run an additional optimization pass that can improve the performance of the generated code.\u003c/p\u003e\n\n\u003cp\u003eThe rest of this post will show benchmarks, explain the optimizations, and discuss when to use this trick. But I also want to highlight something broader: if you’re working at a high level (CUDA, Triton, PyTorch), you’re still at the mercy of what the backend compilers decide to do. In this case, ptxas (a black box) is making optimization decisions based on your kernel’s name\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eWith the recent release of \u003ca href=\"https://docs.nvidia.com/cuda/tile-ir/sections/introduction.html\"\u003eTileIR\u003c/a\u003e, there’s still plenty of magic happening under the hood.
\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileiras\u003c/code\u003e is also a black box, so we could easily see a similar “cutlass” trick emerge there too\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003e\u003ca href=\"#so-what-is-it-doing\"\u003eIf you want to skip to the TLDR of the optimization, click here\u003c/a\u003e\u003c/p\u003e\n\n\u003ch2 id=\"a-cutlass-example\"\u003eA cutlass example\u003c/h2\u003e\n\n\u003cp\u003eHere’s an example graph showing cutlass benchmarks with and without this optimization (where \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ebaseline/cutlass_on\u003c/code\u003e enables the optimization and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass_off\u003c/code\u003e disables it):\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/main_example.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput of various cutlass examples\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eIn particular, the \u003ca href=\"https://docs.nvidia.com/cutlass/media/docs/cpp/cute/0x_gemm_tutorial.html#sgemm-2-cu\"\u003eCuTE sgemm2.cu\u003c/a\u003e \u003ca href=\"https://github.com/NVIDIA/cutlass/blob/v4.3.0/examples/cute/tutorial/sgemm_2.cu\"\u003eexample\u003c/a\u003e sees a 20% drop in performance without the cutlass optimization!\u003c/p\u003e\n\n\u003cp\u003eAnother thing that is immediately obvious is that this optimization doesn’t always increase performance.\u003c/p\u003e\n\n\u003ch1 id=\"benchmarks\"\u003eBenchmarks\u003c/h1\u003e\n\n\u003cp\u003eBelow are sections you can expand to see various benchmarks running on an RTX 3090 and an H100.
Each result is aggregated from 5 benchmark runs.\u003c/p\u003e\n\n\u003cp\u003eBenchmarks include 15+ projects, covering popular ones like PyTorch, Flash Attention 2/3, Cutlass, llama.cpp.\u003c/p\u003e\n\n\u003cp\u003eSome highlights:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003eRunning llama.cpp on RTX 3090 with gpt-oss-20b shows a 1%+ performance increase\u003c/li\u003e\n \u003cli\u003eFlash Attention 2 on RTX 3090/H100 without the optimization decreases performance by up to 1%\u003c/li\u003e\n \u003cli\u003eTriton on RTX 3090 generally shows no performance change from the optimization\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eNote: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ebaseline\u003c/code\u003e doesn’t change anything. \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass_on\u003c/code\u003e enables the optimization and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass_off\u003c/code\u003e disables it (if the application uses \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass\u003c/code\u003e, for example Flash Attention 3):\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see 3090 benchmarks\u003c/summary\u003e\n\n \u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n \u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n \u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n \u003cdiv id=\"benchmark-3090-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003ctable id=\"benchmark-3090-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n GPU\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Benchmarks\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRTX 3090 (Ampere)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ebitsandbytes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ecandle\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ecutlass\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eflash_attn2\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col5\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eflashinfer\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col6\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eggml\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col7\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eliger\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col8\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellamacpp\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col9\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ellmc\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col10\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003emojo\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col11\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003enccl\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col12\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003epytorch\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col13\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003esageattention\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col14\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003esgemm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col15\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003esglang\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col16\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etilus\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col17\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etinygrad\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col18\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: 
rgb(75, 85, 99) !important;\"\u003etorchao\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col19\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003etriton\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col20\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eunsloth\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-3090-table-row0-col21\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003evllm\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/bitsandbytes_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/candle_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/cutlass_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/flash_attn2_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/flashinfer_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv 
class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/ggml_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/liger_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/llamacpp_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/llmc_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/mojo_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/nccl_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/pytorch_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/sageattention_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/sgemm_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" 
src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/sglang_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/tilus_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/tinygrad_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/torchao_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/triton_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/unsloth_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/ampere/vllm_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see H100 benchmarks\u003c/summary\u003e\n\n \u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n \u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n \u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n \u003cdiv id=\"benchmark-h100-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003ctable id=\"benchmark-h100-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n GPU\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Benchmarks\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eH100 (Hopper)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ebitsandbytes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ecutlass\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003edeepep\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003edeepgemm_tflops\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col5\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eflash_attn2\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col6\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eflash_attn3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"benchmark-h100-table-row0-col7\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eflashinfer\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/bitsandbytes_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/cutlass_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/deepep_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n 
\n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/deepgemm_tflops_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/flash_attn2_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/flash_attn3_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/benchmarks/hopper/flashinfer_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch1 id=\"so-what-has-it-changed\"\u003eSo what has it changed?\u003c/h1\u003e\n\n\u003cp\u003eI’ve added Godbolt links so you can see the difference yourself. I’m using parts of \u003ca href=\"https://github.com/siboehm/SGEMM_CUDA/blob/master/src/kernels/9_kernel_autotuned.cuh\"\u003eSGEMM_CUDA\u003c/a\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eIf you haven’t checked it out, it’s \u003ca href=\"https://siboehm.com/articles/22/CUDA-MMM\"\u003ea great blog\u003c/a\u003e on optimizing CUDA matmul kernels by Simon Boehm\u003c/span\u003e as a reference.\u003c/p\u003e\n\n\u003cp\u003eIn the NVCC compilation pipeline, CUDA is compiled to PTX, and PTX is then assembled into SASS. 
Let’s verify where this optimization is applied: in the PTX or in the SASS?\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/gpu_compilation.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eHigh-level compilation overview for NVIDIA GPUs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eFirst, let’s check whether the CUDA-to-PTX output has changed.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/cuda_to_ptx.svg\" style=\"width: 140%; margin-left: calc((100% - 140%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThere's no difference in the PTX!\n \n (Image source: \u003ca href=\"https://godbolt.org/z/bcfj8ovrc\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eGodbolt link\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eOnly the kernel name has changed. 
The PTX instructions are identical.\u003c/p\u003e\n\n\u003cp\u003eSo let’s now check the the sass \u003ca href=\"https://godbolt.org/z/erc4e8M17\"\u003eGodbolt link\u003c/a\u003e:\u003c/p\u003e\n\n\u003cdiv style=\"\n width: 90vw;\n margin-left: calc(50% - 45vw);\n margin-right: calc(50% - 45vw);\n\"\u003e\n \u003ciframe src=\"https://godbolt.org/e#z:OYLghAFBqd5QCxAYwPYBMCmBRdBLAF1QCcAaPECAMzwBtMA7AQwFtMQByARg9KtQYEAysib0QXACx8BBAKoBnTAAUAHpwAMvAFYTStJg1DIAruiakl9ZATwDKjdAGFUtEywYgATKUcAZPAZMADl3ACNMYgkAVgA2UgAHVAVCOwYXNw9vROTUgQCg0JYIqK44y0xrWwEhAiZiAgz3Tx8rTBs02vqCApDwyJj4hTqGpqzWkZ7AvuKBstiASktUE2Jkdg4AUg0AQU2vAGZA5DcsAGpNg6dTcwB9YhNBPDYAOgRL7G293f2jqiwqGcnHIACp%2BHZCIRfX4A6ZnADS2AASsFsH5bsEdgBZbAQZhsBZnUwEAwKBS3X6/fGYaGHSpKWkHWFBBHI1HozE4vGsTCE6mMxx4KjQn6HZmYM4AdR2SOUQgAkgAtbBnA5eRnioHYeXogAi8oAahAsaQzsFCVBjYT9gAhM4Qc0AWi4hIA9PbzQsRTtXe7Hf6A4Gg8GQ6Gw%2BGIxGvr6zgb2kRiHgAF6YdBnYC0VBhMSOj4KBD1VNnTNMNMISoJSIKe38NYSvy6gDiLy4XgAHK6hCChC322dAmchBChF7djHIxPJ1PpyGvgRMCwEgZ55cnIECGcbSb%2B4JN8FTevN/CD7viKgAO61RNYHYnjdny8Ea%2BYG0fL63W5YABueHWH7OX6oHgaYlugABiZ4sI2bAsBAaAMMMZxUCWBAAFRnLeRICIhyGoEwaGbneCJEfuXxnORFGUVR1E0UhKHoTsCimrh%2BHoTaTFkbRXHcTuG6BEExBIhemGHvxkQZLenE8dJlGiQwAlCeeNpEWJxAZDa1oAOyvrs5ExqBGFnLY9Cmk%2BhgKEkShpuuqBnPmhboJxvwJMQTDACwTBnI8Z60LQnG1vah6oFQgKXLqZwaJcdrBaFXh2g%2BV7AZgOwXFcYWbliUVnDFFxxelCVPklOxaTpKUUSxBCSABqXhcQmDrpELmYAQtyiMMq7wThKGSKhHwQFJVH7LEezRDaECqYpKW2tlIWEuh8K5XaqkSWc6GSJs0S6l6o2RZtUUDYxG1jctrgpWti0RXNGUXRNwkXTFG3helX4vKo%2B26RRh2jeN8niadq1nFV00ugDW43b9gl3dND2bTVAEvAAnu9ZXkV9x0Qyt53TV4V1g9Nt3nlNcUzcKsPPS8SbI5RaM/QJmOAxdBy41i4MKVDxMw09BzhS955Uxcmm6px%2Bl4WmdrGZgpq0NM9TFkwCMrAQTmHC5bkeV5DA%2BX5H0BT9G45elkUHNFIUXQVz52qu6U2vCWUG3l3NnObSUaQLpWUXVDXEE1LVtUrVwVT1fVDexR205EimWxzs2g8ErN/bQdprY9202rtXNCx91Ge4IjV1b7TDtVcnUboHvUHLiA20SHYcE1HJtUFdcf4xjrhJ4DKdHen/ObIL0J9z8Y5%2BjOI%2Bj2P/rRu6kr1AkroKCYYSOueM9O5gwB4MMkSOmEmbIAA1kWHkECwJi0EhkF2QWdVphLCiT2c4%2BP0/Ua7POi7LjSVyHluRE2vuvFHiIpKbch5JT/1AVieUIJkRAOCFAmBVdZK7klEIOQP8AEoLQeA3cIIQE4OCG%2BXYH5vy/kwP%2BQC
wEzguVQOsMkEFUAsCEDBag9FV7AG3BVdCdVgD/04UZBAdVSxIkwPPWgBAOJZxklI8iJc6J4QIoxU0si%2BHsVIIg6RPETCHmXt7RSpotG7h0QkDIajJEaOkgYjcBABGYCEReeUDBp7e30YeaxgjnCuAcU4hIJVlYHFVu5Ty3lXDazKrrSxZx0CoA3IbLKUSYlXCPFlW0tp4m%2BMkc5VygSNZawGuEw8xA4ZG3iqlJwUpIHQKRMkvKcViCjhRhRTJasgmaxCYg/Ju48BFKyl01cZxcHVNfHFPA9TaLcMyqNQp6FcE3UenDNG8TQYs2mkYxSANgFm3WagvGeUzHmNom42x6BFJeJXtM5ZxM8CPWRgNJp2Tgm%2BTySQe0ETkDdONkSUpUo4GVMGbaZAozBoqyyerB5oTqIdL4u8paXyQQEI%2BSk4ZgLs5rwIaNN50zm6XLmdbO%2Bo1FlsSxXaIxGR1lEs%2BehTBf9FrqP2VRQ5pYMinO9gDOFsy9rG29JRO5oLWmPMkZCp20KhV9OAfAqpCKak2jqYgnlLTcl7MFW82JHzlWJLAeKv5cUAW0rlTktpezyKCt6Y7Yp/ZYWZUlUMm0IzaXkT1WCu1SFnkQAidoYV7q%2Blwq1TabQyKuIMuOSI0%2B4iw5TP6RcpaV0IAasqay80F0MX9PJdoHFDtM4NOkuMo64aZn4zmVw1FR0k1sumqmjl7sBYZt7tWoeD9n4NsbbmOtWImADiMY6YygRgBnAPsQIIZ8IBH2QOWasChgALhYLlLwfaB1nC4JFLwXgsIMBoD2gKCgWC3AAJwaCYPU8cTaj3jznAuJc%2BFP5rl3Og7%2B2CNw2yAXgjcYDYHiqIrg99pFDXkUPMEOQWJbgggABJImwDsXUUIK7vk/JgH8f5bhnA/LWdYgRpZBHIUBNM47J23CMV2owtxnhLjgthUurCPx1WGImGw/5JLfrpQxmiyjyP3BEYVGjCHlJOsYzxuRrFEOsao7%2BFqCGnCmMzbxyTP7dxPrNEReE6SyqyNeWs9KO8aF73lOgVQiN%2BbKcPMgUland6ae069G5H19OGJnlp1QcNA22demcKeMo5RKk%2BJyyzpGvLaJnkZx2RjbO5WiPaalU9zR6e8xE1ZF44aBe0850Lcdwtek898JTpHMCqBcj55BFTkRwxjSzSlCb3QxtcwqZUrLivJrJeK1LlaS5ZZyxEqlyzHYbKnvliVjXMvZcKa17ZzcOvJe%2BZqtLnErNWOAsyhI9mbGlkc3Z/YIXpSykqx53rCENwRMDUyxxK90q2HQLN4L9oqWjbhQ1ybUXXELeOfYg7LKjszae3NsrF3EtXYs2VD89lr7/gqhhPFNowZsVthWzif2r6pkByhTcIObax2uRNrz23ct8QhpNeb7iluJYgEj90khrto8QhEk6tApqOwc2Z5bXgQsE4WkTknGXtvNYGwUi8iUbyFb/QB4DoHwNCABsTxLNsfsyNu50rHF5LbU/u3jsrYXAYs8l%2Bj8nrdE446OXjlbSXEvE8i2z/rGOnZc8KlgOX4U%2BeAZA2BiD%2BPleG9R2E%2BHgbhGiNDaNMVcbzl1d98mnFjte5pxrfzIH2bvfdeq0Hp62lIp9zS0a%2BH3C0U2ljQVzFse3YJ6Fi7tXXV5EMR3nveZF04JrLYliK6tsk8rsL/xm0Je4b1yJKSwl4fWFOGb%2BlMp00K%2BxarxdGL551nV4BuSwzrhY7D789PkrEvnWFNdYeMIpeVV2jX18lfggFi183%2Bv9NR5FOUVAvQqCMFVzoL/qaB9ZvHzPkws7S3fUdgl9v%2B/4icnge34kRJqTPGBMIkmudG/%2BABDGdcJ4dMbcquFEf2CMDANg92CgEAsB5E1CtCCg5%2BTCC4V%2B24N%2BgC5SpoL65Sb6Y2lSxB2y24F2poH6geFc/U364ypoqepk92HuIaTEP%2BCOpoI%2BvBc%2BtA4mBy7Bj
23ibB7i%2B23iaBGEPeR%2B4ude5ETetAh%2Br4jsSO6E8Kla5E8BiBgaKBsBYeg8ZUDqfK4KS%2BLyBSwqhSoq3WPqMqGSwKzS%2Bq/KDSSqwqaqZSGePWNK2qJ%2BVEfCLgJgCS4U3eyhmylKVBE%2BiaWyWC/M3Kjh9yph7SLq5OwqJqZSAyVqtotq36JhCqYBgqnqpqWURRGRmhZwnqR%2BouIeTqgc1U%2BU9Uuc3s%2BcrUhc/sTgZcwc9GQKsQgRSs307uohZyEaN0TcNK3R3Ee2nib28aF05aW0XcKOWhTGUufECW6UEAuaka/Y0aXh8a1oxMJaKacRNEL0dOqh4U7uwaYiIOwE5xEUSxTqL0CMNK8u7iHBNxR0dxF0XAjx3RL0SYrxlx7B1xXuNq6xxM6okO/xLwo%2BKSbxRyHxYJ3x00BwfxYB5EOc84zRzUrRRcHR3U5clcExjSXgvRisYcgx54p2fu%2BMYxKSJJwhEh0x3isxZancO02eX4Jx3KA8DShhewfJ98LgiB%2BcEogQwwhgtg%2BEaQ/YWATwogZ8RAdkE6LAU6s6lQ86Gg9oOwsQGg%2BproQGC62pCQ%2BECA9SWW2JDA06Tg%2Bwy6H4GYWYYg/4H4BgjwI6twYQKwDA6A5ItwEArYbYhIFCaYiIKIaIGI2IuI2GaptwGptAtwC6CwjBDSzGReGESi3mKizEXeREsmd6CIfhMZW6eGdA3ahG78q4gZpo1Z868QZwsQ0gDZTZTZbYpoTZgZr%2Bt%2BpoYmZw24/8CmNygsHASwtAnA0QvAngHAWgpAqAnAwIuoKUtpUqdkKwdYuUBwPApABAmgI5Swe83gXALwBwkgsQsQ0QXAmk%2Bp26mksQBwmkmk%2BgnAkgk5u5s5nAvACgIAGg25u5SwcAsASA%2BAIU5AlA/AggIgYg7AUgMgggigKg6g05vAtACAX5PgKFCgwFVABACMlYIAj5GFlprkqACQ1QU5Q4kI9oaAi4QRZCu6FgdkW69FCwvAxAqFIABwpAbFmFQo2FuF7Aj53FRFTAJFZFnAFFwuJGNF84O6e6pom6sl%2B6I5T5HAE5pAU5M5c5HAuovFsYeAmA54kQZoBoTgZSXABwLwGgLYD89QI6YUCl9FD8aGjR/AD8CgawjogQjopFdmD8AAGj5YXLlLEAAI4mDRJRSOgADyaJZJYVEVHyX41YwQJlZlFlVlXANlaw7w3MDle6Tl0wgQrljo7lyAnlDA3lBAvljoAVVVQVQ08V/sNo0VsVoV4VTVD8uowIYIw4vAO5SFCwSw5YpYAwqBpAB5mk0QLwkgmku6hwZ5kgkgJ50gY5HAL56lb5Wln535v5A1pAAFUAMAiAIAJC6woFUlCQdAkQwQPInAKVpl866VLYvAp1GwygIIflVFDCCQtFilLFvg%2BACYwEeg4Fwgip0F0goN8Fagb5ugnFbQHQ9gEAjgYwngXAvgPpvQRQJQeglkeQ6QrgzQuNuQZFWN/QpQFQVQnQkwqNegCNZFXQDQZNswFNkpowhNWQ6NbNUwhQ5NEgSwT4mA9U6A35o545r5SF752lsGpCelBlRl91aVll1lEAuAhAzyvw6NQI31V11hhwTMfVf5%2B5IA0QmkLwbYkgN5sQbYbYepbYmkttGg0QKl61LA3gztGlvAW1lgO1/VWg/5R1EAQFMtZ1FAF1utN1bAd1qVj1ytW5r1nA71n1F1v1zFvAqY6t14INsgkF4gMFUNSgMNktugrQlQ8YaQDgPptN6N/g0w2NcwOQKQZF1djd%2BNzNONXNZdiNDAjNjQHNaNlN5dNQkw7dDd3NLd3No9pQAtdUwtotKlalntUtuoIdEoBo%2BlhlhSitsdGV9oatCYG5WtLgi4utG5OMhtA1Q1Ryo1Yta1vAbt9OG1kt3tX5P5ftylB1QdKAOt9AZAYd1Fl1v9IADAX4yAyA5lGgJg6NNAYiVYlAYQb5YQg
Q9QCMnAW5SDzAxACMUVYQ2g8YaDvA1FbAggUVDAtAqDktWAYQJgwATgYgtAX53AL1C4hgGYGwM5%2BAdUHQX4Iib5WW7QtFBD5Aucq1M50sYQrkWDLgWAb5hUbtTDpAPDxAXpSgK9HkRgaGoAe1yEbkCg69BlUVlYU5W5oNudENsF8ghdiFM5Jd%2BgrDKAZgFg4jX5kASwolaQjDjotlOVuoeVnkjozlRVtkJVHlXlPl5ENVgV1YdpLVdpXtSjz4LjY1pphQt1HA295lcd9oDAtkxIpICg/19NFdyNVd/degtdvNLNxNTdaQLdeNpNddfNndVNw93QE9XdDNI9jTVTXNNNZTvT3QU9/Nywqw6wwzq1i9m10dD1mTu9eIuTQR%2BT/179g1pAw1WAUQY1q1rt7tT9mlH5Ptb9f5%2B1gdSAADut515zQDIDYDEDUDfAdA2JX5EACDktGDKDQj7zWDODeDNgQjRDjABApD5Db5VDNDdDvkjD8dLDGj7DL1eAXDtgPDjDM5/DxIGwW5DUojyFeAEjKD0jcL25iY8jW5SjKjmAajrDmjxzOjwAejG9hjjAQjpj4NEgkNsg0N1jOgHFdjRgDj5g%2BguLSTbjpFHjnAXj2V9lTF%2BVAThVq6wTpV5VlV1VtVqg9VS6sTS6nV3V4IkI8TkQiT8ASwKTIQaTxlMzT1mVcEizhcBTn5HTxTKN/TGN6AQz6N9TtTzrHr%2BQ3THdg93dvd7TLTPdXTlTfr49zrk9vrcwSwCga5YzLoC9Et%2BzHA5rStczeTtryzRtaz19mzt9Ozj9S9L9vtObB5Fll5S1Xgp5J50QbY26S60Qztq1BwybXtBzKzt9XgbbUtF9/tSwSjKQ9gkgQAA%3D%3D%3D\" style=\"\n width: 100%;\n height: 800px;\n border: 0;\n display: block;\n \" loading=\"lazy\"\u003e\n \u003c/iframe\u003e\n\u003c/div\u003e\n\n\u003cp\u003eClearly something has changed!\u003c/p\u003e\n\n\u003cp\u003eTwo common changes we can see are:\u003c/p\u003e\n\n\u003c!-- https://godbolt.org/z/7TKvhv4Gj --\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/instruction_selection.svg\" style=\"width: 160%; margin-left: calc((100% - 160%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThe optimization now uses IMAD instead of HFMA2.MMA to move constants\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eWe can see that \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eIMAD\u003c/code\u003e is used instead of \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eHFMA2.MMA\u003c/code\u003e for moving constants, which is neat!\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan 
class=\"sidenote\"\u003eBy using \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eIMAD\u003c/code\u003e, these moves can run on the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eFP32\u003c/code\u003e units. Refer to the \u003ca href=\"#h100-sm-diagram\"\u003eH100 SM Diagram\u003c/a\u003e\u003c/span\u003e.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/instruction_reordering.svg\" style=\"width: 140%; margin-left: calc((100% - 140%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eLDS and FFMA instructions are now interleaved\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eWe can also see that \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eLDS\u003c/code\u003e instructions are now interleaved with \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eFFMA\u003c/code\u003e instructions instead of being grouped together\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis should help increase instruction-level parallelism, since shared-memory loads can overlap with math\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003eOne thing the disassembly doesn’t show is register pressure. 
This optimization may increase register pressure:\u003c/p\u003e\n\n\u003cdiv class=\"language-bash highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003ecuobjdump \u003cspan class=\"nt\"\u003e--dump-resource-usage\u003c/span\u003e baseline.cubin\n\n Resource usage:\n Common:\n GLOBAL:0\n Function sgemm_kernel_10:\n REG:188 STACK:0 SHARED:17408 LOCAL:0 CONSTANT[0]:564 TEXTURE:0 SURFACE:0 SAMPLER:0\n\ncuobjdump \u003cspan class=\"nt\"\u003e--dump-resource-usage\u003c/span\u003e cutlass.cubin\n\n Resource usage:\n Common:\n GLOBAL:0\n Function cutlass_sgemm_kernel_9:\n REG:214 STACK:0 SHARED:17408 LOCAL:0 CONSTANT[0]:564 TEXTURE:0 SURFACE:0 SAMPLER:0\n\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eRegister usage increased from \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e188\u003c/code\u003e to \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e214\u003c/code\u003e registers per thread, roughly a \u003ccode class=\"language-plaintext highlighter-rouge\"\u003e14%\u003c/code\u003e increase. However, this isn’t always the case: in other examples, I’ve seen register pressure stay the same or even decrease.\u003c/p\u003e\n\n\u003cp\u003eThe table below lists the instructions that changed for this kernel:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"sass-diff-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003ctable id=\"sass-diff-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Mnemonic\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Baseline\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n CUTLASS\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Δ\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eIMAD.MOV.U32\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sass-diff-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e37\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e+37\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHFMA2.MMA\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e-5\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLEA\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003e15\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e-13\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eIMAD.SHL.U32\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row3-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e+10\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCS2R\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sass-diff-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e75\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e64\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row4-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e-11\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMOV\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e8\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row5-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e-8\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003eIMAD\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e8\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row6-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e+8\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eULDC.64\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row7-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e-3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sass-diff-table-row8-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eFFMA\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row8-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e787\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row8-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e801\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sass-diff-table-row8-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e+14\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003ch1 id=\"so-what-is-it-doing\"\u003eSo… what is it doing?\u003c/h1\u003e\n\n\u003cp\u003eSo far, we’ve dug into specific instruction differences. Taken together, the optimization most likely does the following:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003eInstruction selection - use f32 units for moving constants\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eMoving constants from registers isn’t in the hot path, but it’s an easy example to see!\u003c/span\u003e registers\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eBut wait, there’s more! 
I didn’t cover it in detail in this post, but you can see some IMADs replacing other instructions\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eInstruction reordering - mix memory loads with math\u003c/li\u003e\n \u003cli\u003eInfluence register pressure - may increase the number of registers used to achieve the reordering\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdiv class=\"language-md highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003eWhen ptxas sees matrix operations (MAD/MMA):\n\n Instruction selection:\n HFMA2.MMA, MOV -\u0026gt; IMAD\n\n Instruction reordering:\n LDS spread across FFMA\n\n As a side effect:\n May increase register pressure\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch1 id=\"when-should-you-apply-this-optimization\"\u003eWhen should you apply this optimization?\u003c/h1\u003e\n\n\u003cp\u003eWhen writing kernels, it’s tricky to say when you absolutely should or shouldn’t use this optimization. The optimization seems to increase ILP at the cost of register pressure\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eIn some cases it doesn’t increase register pressure at all!\u003c/span\u003e. 
Always benchmark to confirm that the optimization actually helps\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eI’ve seen the optimization have no effect on performance on some cards while significantly affecting others\u003c/span\u003e.\u003c/p\u003e\n\n\u003ch1 id=\"how-to-apply-this-to-triton\"\u003eHow to apply this to Triton\u003c/h1\u003e\n\n\u003cdiv class=\"language-python highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"kn\"\u003eimport\u003c/span\u003e \u003cspan class=\"n\"\u003etorch\u003c/span\u003e\n\u003cspan class=\"kn\"\u003eimport\u003c/span\u003e \u003cspan class=\"n\"\u003etriton\u003c/span\u003e\n\u003cspan class=\"kn\"\u003eimport\u003c/span\u003e \u003cspan class=\"n\"\u003etriton.language\u003c/span\u003e \u003cspan class=\"k\"\u003eas\u003c/span\u003e \u003cspan class=\"n\"\u003etl\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003edef\u003c/span\u003e \u003cspan class=\"nf\"\u003erename_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eproxy\u003c/span\u003e\u003cspan class=\"p\"\u003e):\u003c/span\u003e\n \u003cspan class=\"k\"\u003ereturn\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003ecutlass_kernel\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e# renames \"my_kernel\" -\u0026gt; \"cutlass_kernel\" in the compiled output\n\u003c/span\u003e\u003cspan class=\"nd\"\u003e@triton.jit\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nb\"\u003erepr\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"n\"\u003erename_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003edef\u003c/span\u003e \u003cspan class=\"nf\"\u003emy_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e \u003cspan class=\"n\"\u003etl\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003econstexpr\u003c/span\u003e\u003cspan class=\"p\"\u003e):\u003c/span\u003e\n \u003cspan class=\"k\"\u003epass\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e# compile and extract ptx\n\u003c/span\u003e\u003cspan class=\"n\"\u003emy_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e[(\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e,)](\u003c/span\u003e\u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"o\"\u003e=\u003c/span\u003e\u003cspan class=\"mi\"\u003e32\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"n\"\u003edev\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003etorch\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ecuda\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003ecurrent_device\u003c/span\u003e\u003cspan class=\"p\"\u003e()\u003c/span\u003e\n\u003cspan class=\"n\"\u003ekernel_cache\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003emy_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003edevice_caches\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003edev\u003c/span\u003e\u003cspan class=\"p\"\u003e][\u003c/span\u003e\u003cspan class=\"mi\"\u003e0\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\u003cspan class=\"n\"\u003ecompiled\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nf\"\u003enext\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nf\"\u003eiter\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan 
class=\"n\"\u003ekernel_cache\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003evalues\u003c/span\u003e\u003cspan class=\"p\"\u003e()))\u003c/span\u003e\n\u003cspan class=\"n\"\u003eptx\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003ecompiled\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003easm\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003eptx\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e# print the kernel name from PTX\n\u003c/span\u003e\u003cspan class=\"nf\"\u003eprint\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"sh\"\u003e'\u003c/span\u003e\u003cspan class=\"se\"\u003e\\n\u003c/span\u003e\u003cspan class=\"sh\"\u003e'\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003ejoin\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eptx\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003esplitlines\u003c/span\u003e\u003cspan class=\"p\"\u003e()[:\u003c/span\u003e\u003cspan class=\"mi\"\u003e20\u003c/span\u003e\u003cspan class=\"p\"\u003e]))\u003c/span\u003e\n\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eIt will show\u003c/p\u003e\n\n\u003cdiv class=\"language-c highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"c1\"\u003e//\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e// Generated by LLVM NVPTX Back-End\u003c/span\u003e\n\u003cspan class=\"c1\"\u003e//\u003c/span\u003e\n\n\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eversion\u003c/span\u003e \u003cspan 
class=\"mi\"\u003e8\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"mi\"\u003e7\u003c/span\u003e\n\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003etarget\u003c/span\u003e \u003cspan class=\"n\"\u003esm_86\u003c/span\u003e\n\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eaddress_size\u003c/span\u003e \u003cspan class=\"mi\"\u003e64\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e// .globl cutlass_kernel // -- Begin function cutlass_kernel\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e// @cutlass_kernel\u003c/span\u003e\n\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003evisible\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eentry\u003c/span\u003e \u003cspan class=\"n\"\u003ecutlass_kernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\n \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eparam\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eu64\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eptr\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eglobal\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ealign\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e \u003cspan class=\"n\"\u003ecutlass_kernel_param_0\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\n \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eparam\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eu64\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eptr\u003c/span\u003e \u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003eglobal\u003c/span\u003e \u003cspan 
class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"n\"\u003ealign\u003c/span\u003e \u003cspan class=\"mi\"\u003e1\u003c/span\u003e \u003cspan class=\"n\"\u003ecutlass_kernel_param_1\u003c/span\u003e\n\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch1 id=\"how-to-apply-this-to-ptxas\"\u003eHow to apply this to ptxas\u003c/h1\u003e\n\n\u003cp\u003eA universal patch to ptxas (which most frameworks invoke) is to just replace \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass\u003c/code\u003e in the binary with something else.\u003c/p\u003e\n\n\u003cp\u003eHere’s how I do it:\u003c/p\u003e\n\n\u003cdiv class=\"language-python highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003einput_path\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003e/usr/local/cuda/bin/ptxas\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\n\u003cspan class=\"n\"\u003eoutput_path\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003eptxas_no_cutlass\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003ewith\u003c/span\u003e \u003cspan class=\"nf\"\u003eopen\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003einput_path\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003erb\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003eas\u003c/span\u003e \u003cspan class=\"n\"\u003ef\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003eblob\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"nf\"\u003ebytearray\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003ef\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003eread\u003c/span\u003e\u003cspan class=\"p\"\u003e())\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e# We expect the literal string \"cutlass\" inside ptxas; only the first occurrence is patched.\n\u003c/span\u003e\u003cspan class=\"n\"\u003etarget\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"sa\"\u003eb\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003ecutlass\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\n\u003cspan class=\"n\"\u003eoff\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"n\"\u003eblob\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003efind\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003cspan class=\"k\"\u003eassert\u003c/span\u003e \u003cspan class=\"n\"\u003eoff\u003c/span\u003e \u003cspan class=\"o\"\u003e!=\u003c/span\u003e \u003cspan class=\"o\"\u003e-\u003c/span\u003e\u003cspan class=\"mi\"\u003e1\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003eptxas did not contain the cutlass marker!\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\n\n\u003cspan class=\"c1\"\u003e# Overwrite: c u t l a s s → ff ff ff ff ff ff ff. Kernel names are ASCII, so a strstr() match against the 0xFF bytes can never succeed\n\u003c/span\u003e\u003cspan class=\"k\"\u003efor\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e \u003cspan class=\"ow\"\u003ein\u003c/span\u003e \u003cspan class=\"nf\"\u003erange\u003c/span\u003e\u003cspan 
class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"nf\"\u003elen\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e)):\u003c/span\u003e\n \u003cspan class=\"n\"\u003eblob\u003c/span\u003e\u003cspan class=\"p\"\u003e[\u003c/span\u003e\u003cspan class=\"n\"\u003eoff\u003c/span\u003e \u003cspan class=\"o\"\u003e+\u003c/span\u003e \u003cspan class=\"n\"\u003ei\u003c/span\u003e\u003cspan class=\"p\"\u003e]\u003c/span\u003e \u003cspan class=\"o\"\u003e=\u003c/span\u003e \u003cspan class=\"mh\"\u003e0xFF\u003c/span\u003e\n\n\u003cspan class=\"k\"\u003ewith\u003c/span\u003e \u003cspan class=\"nf\"\u003eopen\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eoutput_path\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003ewb\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"k\"\u003eas\u003c/span\u003e \u003cspan class=\"n\"\u003ef\u003c/span\u003e\u003cspan class=\"p\"\u003e:\u003c/span\u003e\n \u003cspan class=\"n\"\u003ef\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003ewrite\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eblob\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\n\u003cspan class=\"nf\"\u003eprint\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"sa\"\u003ef\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"s\"\u003epatched \u003c/span\u003e\u003cspan class=\"sh\"\u003e'\u003c/span\u003e\u003cspan class=\"si\"\u003e{\u003c/span\u003e\u003cspan class=\"n\"\u003etarget\u003c/span\u003e\u003cspan class=\"p\"\u003e.\u003c/span\u003e\u003cspan class=\"nf\"\u003edecode\u003c/span\u003e\u003cspan 
class=\"p\"\u003e()\u003c/span\u003e\u003cspan class=\"si\"\u003e}\u003c/span\u003e\u003cspan class=\"sh\"\u003e'\u003c/span\u003e\u003cspan class=\"s\"\u003e at offset \u003c/span\u003e\u003cspan class=\"si\"\u003e{\u003c/span\u003e\u003cspan class=\"n\"\u003eoff\u003c/span\u003e\u003cspan class=\"si\"\u003e:\u003c/span\u003e\u003cspan class=\"c1\"\u003e#x\u003c/span\u003e\u003cspan class=\"si\"\u003e}\u003c/span\u003e\u003cspan class=\"sh\"\u003e\"\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch1 id=\"resolving-public-statements\"\u003eResolving Public Statements\u003c/h1\u003e\n\n\u003cp\u003eIn my opinion, people on the internet have made a lot of assumptions about this optimization. I want to clear some of them up.\u003c/p\u003e\n\n\u003cp\u003eAt the top of the \u003ca href=\"https://news.ycombinator.com/item?id=45458948\"\u003eHacker News post\u003c/a\u003e is a link to a user’s response about this optimization:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/unstable.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThis statement is incorrect: I have compiled many real-world projects with this optimization both on and off, and they ran without failing (passing output asserts) on different cards.\u003c/p\u003e\n\n\u003cp\u003eThere is also \u003ca href=\"https://www.reddit.com/r/programming/comments/1nx3g70/fp8_runs_100_tflops_faster_when_the_kernel_name/\"\u003ea highly upvoted Reddit comment\u003c/a\u003e:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/reddit.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThis explanation is really hard to understand. 
I’m guessing the user is claiming that this trick uses NaNs/zeroes to optimize the program. It doesn’t. Instead, it optimizes how registers are moved.\u003c/p\u003e\n\n\u003ch1 id=\"previous-mentions\"\u003ePrevious mentions\u003c/h1\u003e\n\n\u003cp\u003eThis was mentioned before by \u003ca href=\"https://forums.developer.nvidia.com/t/how-does-bar-sync-defer-blocking-get-generated/245747\"\u003egrynet on the NVIDIA forums\u003c/a\u003e, who noticed that the following two kernels, which differ only by an unused \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ecutlass::gemm::GemmCoord\u003c/code\u003e parameter, generate different SASS:\u003c/p\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003e__global__\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003emykernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003elhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003erhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eres\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan 
class=\"n\"\u003eK\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan 
class=\"n\"\u003ecutlass\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003egemm\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eGemmCoord\u003c/span\u003e \u003cspan class=\"n\"\u003eproblem_size\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"n\"\u003eK\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecompute_gemm_with_cutlass\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eres\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eproblem_size\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cdiv class=\"language-cpp highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"n\"\u003e__global__\u003c/span\u003e \u003cspan class=\"kt\"\u003evoid\u003c/span\u003e \u003cspan class=\"nf\"\u003emykernel\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003elhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003erhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"kt\"\u003efloat\u003c/span\u003e \u003cspan class=\"o\"\u003e*\u003c/span\u003e\u003cspan class=\"n\"\u003eres\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"kt\"\u003eint\u003c/span\u003e \u003cspan class=\"n\"\u003eK\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003ecutlass\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003egemm\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eGemmCoord\u003c/span\u003e \u003cspan class=\"n\"\u003edummy\u003c/span\u003e\u003cspan class=\"p\"\u003e)\u003c/span\u003e \u003cspan class=\"p\"\u003e{\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecutlass\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003egemm\u003c/span\u003e\u003cspan class=\"o\"\u003e::\u003c/span\u003e\u003cspan class=\"n\"\u003eGemmCoord\u003c/span\u003e \u003cspan class=\"n\"\u003eproblem_size\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003eM\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"n\"\u003eN\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e\u003cspan class=\"n\"\u003eK\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n \u003cspan class=\"n\"\u003ecompute_gemm_with_cutlass\u003c/span\u003e\u003cspan class=\"p\"\u003e(\u003c/span\u003e\u003cspan class=\"n\"\u003elhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003erhs\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan 
class=\"n\"\u003eres\u003c/span\u003e\u003cspan class=\"p\"\u003e,\u003c/span\u003e \u003cspan class=\"n\"\u003eproblem_size\u003c/span\u003e\u003cspan class=\"p\"\u003e);\u003c/span\u003e\n\u003cspan class=\"p\"\u003e}\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eWith the second kernel, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eBAR.SYNC.DEFER_BLOCKING\u003c/code\u003e is generated instead of \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eBAR.SYNC\u003c/code\u003e, because “cutlass” becomes part of the function signature.\u003c/p\u003e\n\n\u003cp\u003ePerhaps this was part of the same optimization in earlier versions of \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eptxas\u003c/code\u003e?\u003c/p\u003e\n\n\u003ch1 id=\"takeaway\"\u003eTakeaway\u003c/h1\u003e\n\n\u003cp\u003eSo, adding “cutlass” to your kernel name can gain you 100+ TFLOPS, or cost you 20% of your FLOPS.\u003c/p\u003e\n\n\u003cp\u003eThe issue is twofold: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eptxas\u003c/code\u003e is a black box and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eSASS\u003c/code\u003e is undocumented. This is unlike other ecosystems. With LLVM you can see every pass that runs, and the x86 and Arm instruction sets are documented.\u003c/p\u003e\n\n\u003cp\u003eAs for this optimization, it helps some kernels, hurts others, and barely changes the rest. It depends entirely on your architecture and your specific code. What flies on an H100 might tank on a 5090 or B200, and you have no way to know until you run it.\u003c/p\u003e\n\n\u003cp\u003eSo what do you do? Benchmark it. Change the ordering in Triton/CUDA, see if the PTX changes, and check the SASS output. That’s the only way to know what \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eptxas\u003c/code\u003e actually did.\u003c/p\u003e\n\n\u003cp\u003eAnd this isn’t going away. 
\u003ccode class=\"language-plaintext highlighter-rouge\"\u003etileiras\u003c/code\u003e (the new TileIR compiler) is also a black box. We should expect more surprises like this going forward.\u003c/p\u003e\n\n\u003ch1 id=\"appendix\"\u003eAppendix\u003c/h1\u003e\n\n\u003ch2 id=\"nvidia-toolchain-background\"\u003eNVIDIA toolchain background\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/gpu_compilation.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eHigh-level compilation overview for NVIDIA GPUs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eNVIDIA’s toolchain works like this: \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eCUDA code\u003c/code\u003e is compiled by \u003cem\u003envcc\u003c/em\u003e into \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ePTX\u003c/code\u003e, an intermediate representation. 
Then \u003cem\u003eptxas\u003c/em\u003e takes that \u003ccode class=\"language-plaintext highlighter-rouge\"\u003ePTX\u003c/code\u003e and turns it into \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eSASS\u003c/code\u003e, the low-level instruction set the GPU runs\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eptxas and SASS are both undocumented, so it can be difficult to understand what’s going on\u003c/span\u003e.\u003c/p\u003e\n\n\u003ch2 id=\"h100-sm-diagram\"\u003eH100 SM Diagram\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-12-05/gh100.png\" width=\"50%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eH100 SM Diagram\n \n (Image source: \u003ca href=\"https://resources.nvidia.com/en-us-hopper-architecture/nvidia-h100-tensor-c\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eNVIDIA H100 GPU Whitepaper\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003ch2 id=\"changes\"\u003eChanges\u003c/h2\u003e\n\n\u003cp\u003e[12/16/2025] Thanks to @Firadeoclus on the GPU MODE Discord for pointing out that my original post mixed up \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eHMMA\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eHFMA2.MMA\u003c/code\u003e, and that they move constants rather than zeroing.\u003c/p\u003e\n\n\u003ch1 id=\"citation\"\u003eCitation\u003c/h1\u003e\n\n\u003cp\u003eTo cite this article:\u003c/p\u003e\n\n\u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e@article{zhu2025cutlass,\n title = {Maybe consider putting \"cutlass\" in your CUDA/Triton kernels},\n author = {Zhu, Henry},\n journal = {maknee.github.io},\n year = {2025},\n month = {December},\n url = 
\"https://maknee.github.io/blog/2025/Maybe-Consider-Putting-Cutlass-In-Your-CUDA-Kernels/\"\n}\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e","summary":"Motivation","date_published":"2025-12-15T06:00:00+00:00","date_modified":"2025-12-15T06:00:00+00:00","author":{"name":""},"tags":["CUDA"]},{"id":"https://maknee.github.io/blog/2025/3FS-Performance-Journal-3","url":"https://maknee.github.io/blog/2025/3FS-Performance-Journal-3/","title":"Network Storage and Scaling Characteristics of a Distributed Filesystem","content_html":"\u003ch1 id=\"series\"\u003eSeries\u003c/h1\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-1/\"\u003eAn Intro to DeepSeek’s Distributed File System\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-2/\"\u003eA Reality Check on DeepSeek’s Distributed File System Benchmarks\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-3/\"\u003eNetwork Storage and Scaling Characteristics of a Distributed Filesystem\u003c/a\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003c!--\n- [Theoretical Performance Limits of 3FS](/blog/2018/RTX-DXR-Path-Tracer-Host/)\n- [Benchmarking 3FS](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n- [Analysis of 3FS Benchmarks](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n- [Improving 3FS Performance](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n--\u003e\n\n\u003ch1 id=\"table-of-contents\"\u003eTable of Contents\u003c/h1\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003ca href=\"#the-benchmarking-pyramid\"\u003eThe Benchmarking Pyramid\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#network-baseline-benchmark\"\u003eNetwork Baseline Benchmark\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#benchmarking-for-modern-cluster\"\u003eStorage Baseline Benchmark\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#3fs\"\u003e3FS Performance Analysis\u003c/a\u003e\n \u003cul\u003e\n 
\u003cli\u003e\u003ca href=\"#scaling-block-size-5-nodes\"\u003eScaling Block Size\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#scaling-nodes\"\u003eScaling Number of Nodes\u003c/a\u003e\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003e\u003ca href=\"#wrapping-up\"\u003eWrapping up\u003c/a\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch1 id=\"refresher\"\u003eRefresher\u003c/h1\u003e\n\n\u003cp\u003eIn \u003ca href=\"/blog/2025/3FS-Performance-Journal-1/\"\u003emy first post\u003c/a\u003e, I introduced DeepSeek’s \u003ca href=\"https://github.com/deepseek-ai/3FS/tree/ee9a5cee0a85c64f4797bf380257350ca1becd36\"\u003e3FS distributed file system\u003c/a\u003e and performed a \u003ca href=\"/blog/2025/3FS-Performance-Journal-2/\"\u003ereality check in the second post\u003c/a\u003e. Now it’s time to see how 3FS performs in practice.\u003c/p\u003e\n\n\u003ch1 id=\"the-benchmarking-pyramid\"\u003eThe Benchmarking Pyramid\u003c/h1\u003e\n\n\u003cp\u003eBefore diving into results, let’s talk about understanding software performance at a high level. If we imagine performance understanding as an onion, peeling off each layer reveals deeper insights\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eEach layer gives us a deeper understanding. 
Without starting at the top, discovering these insights may be difficult.\u003c/span\u003e\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/increasing_difficulty.svg\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThe performance analysis pyramid: from theoretical limits to production\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eWe started with napkin math in the first post, performed reality checks in the second, and now we’re ready for the next layer: microbenchmarking.\u003c/p\u003e\n\n\u003ch2 id=\"why-microbenchmark\"\u003eWhy Microbenchmark?\u003c/h2\u003e\n\n\u003cp\u003eThink of microbenchmarking as testing individual components in isolation. Instead of running a complex workload that does everything at once, we test one specific operation repeatedly until we understand its exact performance characteristics. It’s like measuring only how fast a car accelerates in a straight line instead of timing a trip through city traffic where you can’t tell if slowdowns are from stop signs, traffic lights, or congested highways.\u003c/p\u003e\n\n\u003cp\u003eBut one might ask: why not jump straight to real workloads? Real workloads are messy. They mix reads, writes, different block sizes, and various access patterns. When something’s slow, is it the network? The disk? The software? That’s the challenge with macrobenchmarks and production workloads (the bottom layers of our pyramid): there are too many variables at once. Microbenchmarks let us isolate each component and understand exactly where time is spent\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThey answer specific questions like: What’s the maximum throughput for sequential reads? How does latency change with queue depth? 
Where exactly does performance fall off a cliff as we increase parallelism?\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003eThese benchmarks build intuition at multiple levels, from raw hardware performance to exactly how 3FS performs. Once you recognize these patterns, you can develop an intuition for why related applications may be slow and how to fix them\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis knowledge transfers across systems too – similar hardware will have similar characteristics regardless of the software running on top, and similar types of software (like filesystems) perform comparable operations\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003eIn my previous posts, I made several predictions about 3FS performance based on napkin math and reality checks. Now that I have actual microbenchmark data, I can see how accurate those predictions were, or how far off I was.\u003c/p\u003e\n\n\u003ch2 id=\"what-were-measuring-and-why\"\u003eWhat we’re measuring and why\u003c/h2\u003e\n\n\u003cp\u003eIn this post, we’ll answer five key questions:\u003c/p\u003e\n\n\u003col\u003e\n \u003cli\u003e\u003cstrong\u003eWhat are the hardware limits?\u003c/strong\u003e – Local SSD and InfiniBand benchmarks establish our ceiling\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eHow does 3FS compare?\u003c/strong\u003e – Performance differences from local benchmarks and why they occur\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eIs 3FS hardware-specific?\u003c/strong\u003e – Does it require high-end hardware or work well on commodity clusters?\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003e\u003ca href=\"https://arxiv.org/pdf/2408.14158\"\u003eDeepSeek’s paper\u003c/a\u003e describes a cluster with NVMe SSDs and 200Gb/s InfiniBand. 
What happens with SATA SSDs and 25Gb/s networking?\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eHow does 3FS scale?\u003c/strong\u003e – Performance across different node counts and configurations\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eWhat knobs matter?\u003c/strong\u003e – Impact of block sizes, I/O patterns, and tuning parameters\u003c/li\u003e\n\u003c/ol\u003e\n\n\u003cp\u003eThis will start to build our intuition for how 3FS performs. The post includes many interactive graphs so you can explore the data yourself\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eI’ll highlight the interesting patterns so you don’t drown in numbers; sometimes benchmarks reveal surprising behaviors\u003c/span\u003e.\u003c/p\u003e\n\n\u003ch1 id=\"single-node-benchmarking\"\u003eSingle Node Benchmarking\u003c/h1\u003e\n\n\u003cp\u003eBefore diving into 3FS performance, we need to understand how our clusters perform. This section establishes baseline performance for both network and storage using standard tools.\u003c/p\u003e\n\n\u003ch2 id=\"testing-environment\"\u003eTesting Environment\u003c/h2\u003e\n\n\u003cp\u003eI have two contrasting setups that tell an interesting story:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Older Cluster (18 Nodes)\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Modern Cluster (5 Nodes)\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNode Count\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e18\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eUse case\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eBudget cluster\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHigh-performance cluster\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCPU\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10-core Intel E5-2640v4 (2017 era)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2×36-core Intel Xeon Platinum (2021 era)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMemory\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e64GB DDR4-2400\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e256GB DDR4-3200\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStorage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 
Nodes)-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSATA SSD (480GB)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row4-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNVMe SSD (1.6TB PCIe 4.0)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 Gbps (3.25 GB/s)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Older Cluster (18 Nodes),Modern Cluster (5 Nodes)-row5-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e100 Gbps (12.5 GB/s)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe older cluster represents deployments using previous-generation hardware. The modern cluster represents somewhat current high-performance deployments. 
Comparing the two shows how 3FS performs across hardware generations\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eI don’t have access to a high-end cluster with many NVMe drives and newer NICs. I’d love to have the setup that the 3FS team uses, but I’m just a student without access to those types of clusters 😔\u003c/span\u003e. I’ll refer to these clusters as \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eold cluster\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003enew cluster\u003c/code\u003e.\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to see more detailed hardware specifications\u003c/summary\u003e\n\n \u003ctable class=\"datatable display compact cell-border row-border hover\"\u003e\n \u003cthead\u003e\n \u003ctr\u003e\n \u003cth\u003eComponent\u003c/th\u003e\n \u003cth\u003eOlder Cluster (18 Node Setup)\u003c/th\u003e\n \u003cth\u003eModern Cluster (5 Node Setup)\u003c/th\u003e\n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \u003ctr\u003e\n \u003ctd\u003eNode Count\u003c/td\u003e\n \u003ctd\u003e18\u003c/td\u003e\n \u003ctd\u003e5\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003eCPU\u003c/td\u003e\n \u003ctd\u003eTen-core Intel E5-2640v4 at 2.4 GHz\u003c/td\u003e\n \u003ctd\u003eTwo 36-core Intel Xeon Platinum 8360Y at 2.4 GHz\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003eRAM\u003c/td\u003e\n \u003ctd\u003e64GB ECC Memory (4x 16 GB DDR4-2400 DIMMs)\u003c/td\u003e\n \u003ctd\u003e256GB ECC Memory (16x 16 GB 3200MHz DDR4)\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003eDisk\u003c/td\u003e\n \u003ctd\u003e\u003ca 
href=\"https://servak.com.ua/image/manual/SSD/SSD_240GB_2.5_6G_INTEL_DC_S3520_SERIES_SATA_Quick_Specs_Servak_2.pdf?srsltid=AfmBOoq8zg_-WF9Sop69GSohu_edCS2TGfP0pINVrR3IfPklqPNjLb5J\"\u003eIntel DC S3520 480 GB 6G SATA SSD\u003c/a\u003e (OS \u0026amp; Workload)\u003c/td\u003e\n \u003ctd\u003e\u003ca href=\"https://semiconductor.samsung.com/ssd/datacenter-ssd/sm883/mz7kh480hahq/\"\u003eSamsung 480GB SATA SSD\u003c/a\u003e (OS)\u003cbr /\u003e\u003ca href=\"https://dl.dell.com/manuals/all-products/esuprt_data_center_infra_int/esuprt_data_center_infra_storage_adapters/dell-poweredge-exp-fsh-nvme-pcie-ssd_users-guide7_en-us.pdf\"\u003eDell NVMe 1.6TB NVMe SSD (PCIe v4.0)\u003c/a\u003e (Workload)\u003c/td\u003e\n \u003c/tr\u003e\n \u003ctr\u003e\n \u003ctd\u003eNetwork\u003c/td\u003e\n \u003ctd\u003eMellanox ConnectX-4 25 GB NIC\u003cbr /\u003e(1.25 GB/s, only one physical port at 25 Gbps)\u003c/td\u003e\n \u003ctd\u003eDual-port Mellanox ConnectX-6 100 Gb NIC\u003cbr /\u003e(12.5 GB/s, Only one physical port enabled)\u003c/td\u003e\n \u003c/tr\u003e\n \u003c/tbody\u003e\n \u003c/table\u003e\n\n \u003c!-- lstopo --no-legend --of svg \u003e cpu.svg --\u003e\n \u003cp\u003eLayout of Older Cluster:\u003c/p\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/setup/setup1.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eOlder Cluster cpu/pcie layout\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cp\u003eLayout of Modern Cluster:\u003c/p\u003e\n \u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/setup/setup2.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eModern Cluster cpu/pcie layout\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"network-baseline-benchmark\"\u003eNetwork Baseline 
Benchmark\u003c/h2\u003e\n\n\u003cp\u003eDistributed filesystems are only as fast as their network, and depending on the workload the network often becomes the primary bottleneck, as shown in \u003ca href=\"/blog/2025/3FS-Performance-Journal-2/#first-workload-training-job\"\u003emy measurements in the previous post\u003c/a\u003e.\u003c/p\u003e\n\n\u003cp\u003eSince 3FS uses InfiniBand for data transfer, we first measure raw network performance using the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eib_send\u003c/code\u003e, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eib_read\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eib_write\u003c/code\u003e benchmarks. These tests tell us two things: how close we can get to the theoretical 12.5 GB/s (100 Gbps) limit, and how latency changes with message size\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eIn a later post, I will profile actual 3FS network traffic to see which message sizes are used and how they map to these latency measurements\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003eThe graph plots three key variables:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eMessage Size (Z-axis):\u003c/strong\u003e Message sizes from 2 bytes to 10 megabytes, on a logarithmic scale\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eThroughput (Y-axis):\u003c/strong\u003e Data transfer rate in GB/s, with color mapping from blue (low) to red (high)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eLatency (X-axis):\u003c/strong\u003e Transfer completion time in microseconds\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand for instructions on how to interact with the graph\u003c/summary\u003e\n\n \u003cp\u003eThe results of the \u003ccode class=\"language-plaintext 
highlighter-rouge\"\u003eib_read_bw\u003c/code\u003e benchmark are plotted in the interactive 3D graph below. You can click and drag to rotate the graph, and hovering over any data point will display its precise values.\u003c/p\u003e\n\n \u003cp\u003eThe \u003cstrong\u003eTest Type\u003c/strong\u003e menu allows you to switch between different benchmark results (\u003ccode class=\"language-plaintext highlighter-rouge\"\u003eib_write\u003c/code\u003e and \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eib_send\u003c/code\u003e). The \u003cstrong\u003eView Mode\u003c/strong\u003e can be changed to 2D, which helps observe latency variations more clearly.\u003c/p\u003e\n\n\u003c/details\u003e\n\n\u003c!-- ib-benchmark.html --\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/ib_benchmark.css\" /\u003e\n\n\u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.27.1/plotly.min.js\"\u003e\u003c/script\u003e\n\n\u003cscript src=\"/assets/js/ib_benchmark.js\"\u003e\u003c/script\u003e\n\n\u003cdiv class=\"ib-benchmark-container\" id=\"ib-benchmark-container-nvme_ib_unidirectional\" data-path=\"/assets/images/posts/2025-03-13/ib/ib_benchmark_unidirectional.json\"\u003e\n \u003ch2\u003eIB benchmark unidirectional\u003c/h2\u003e\n \n \u003cdiv class=\"ib-controls\"\u003e\n \u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"testType-nvme_ib_unidirectional\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-nvme_ib_unidirectional\"\u003e\n \u003coption value=\"send_bw\"\u003eSend Bandwidth\u003c/option\u003e\n \u003coption value=\"send_lat\"\u003eSend Latency\u003c/option\u003e\n \u003coption value=\"read_bw\" selected=\"\"\u003eRead Bandwidth\u003c/option\u003e\n \u003coption value=\"read_lat\"\u003eRead Latency\u003c/option\u003e\n \u003coption value=\"write_bw\"\u003eWrite Bandwidth\u003c/option\u003e\n \u003coption value=\"write_lat\"\u003eWrite Latency\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n 
\u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"viewMode-nvme_ib_unidirectional\"\u003eView Mode\u003c/label\u003e\n \u003cselect id=\"viewMode-nvme_ib_unidirectional\"\u003e\n \u003coption value=\"3d\" selected=\"\"\u003e3D Graph\u003c/option\u003e\n \u003coption value=\"2d\"\u003e2D Graph\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"ib-plot-nvme_ib_unidirectional\" class=\"ib-plot-container ib-lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Panels indicator --\u003e\n \u003cdiv id=\"ib-panels-indicator-nvme_ib_unidirectional\" class=\"ib-panels-indicator\" style=\"display: none;\"\u003e\n \u003cspan\u003eActive Panels: \u003cspan id=\"ib-panel-count-nvme_ib_unidirectional\"\u003e0\u003c/span\u003e\u003c/span\u003e\n \u003cbutton id=\"ib-arrange-btn-nvme_ib_unidirectional\" class=\"ib-action-button\"\u003eArrange\u003c/button\u003e\n \u003cbutton id=\"ib-close-all-btn-nvme_ib_unidirectional\" class=\"ib-action-button\"\u003eClose All\u003c/button\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"ib-benchmark-note\"\u003e\n \u003cp\u003eIB Benchmark on unidirectional throughput/latency\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.ib-benchmark-container');\n const id = container.id.replace('ib-benchmark-container-', '');\n const plotEl = document.getElementById('ib-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadIBBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n 
});\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('ib-plot-nvme_ib_unidirectional');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadIBBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForIBBenchmarkJs() {\n if (typeof initInfiniBandPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('ib-benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n const processedData = processBenchmarkData(data);\n const plotId = 'ib-plot-' + id;\n const testType = document.getElementById('testType-' + id).value;\n const viewMode = document.getElementById('viewMode-' + id).value;\n \n // Use our new function to initialize the plot\n window['ibPlot_' + id] = initInfiniBandPlot(plotId, processedData, {\n defaultTest: testType || 'send_bw',\n viewMode: viewMode || '3d'\n });\n \n document.getElementById(plotId).classList.remove('ib-lazy-load');\n \n // Setup event listeners for controls\n document.getElementById('testType-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setTestType) {\n plotObj.setTestType(e.target.value);\n }\n });\n \n document.getElementById('viewMode-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setViewMode) {\n plotObj.setViewMode(e.target.value);\n }\n });\n })\n 
.catch(error =\u003e {\n console.error('Error loading InfiniBand benchmark data:', error);\n document.getElementById('ib-plot-' + id).innerHTML = \n '\u003cdiv class=\"ib-error\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForIBBenchmarkJs(), 100);\n }\n }\n \n waitForIBBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eKey observations from the throughput graph:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAll three operations (read, write, send) peak at ~11.5 GB/s (92% of theoretical) at 4K-8K message sizes\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eSurprisingly, the send operation (two-sided) achieves the same bandwidth as one-sided RDMA operations. This is unexpected given the additional coordination overhead\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eTo achieve meaningful throughput (\u0026gt;10 GB/s), you need at least 4KB messages\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eSwitching to the latency graph (Read Bandwidth -\u0026gt; Read Latency) reveals additional insights:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAt the same 4K message sizes, latency drops significantly to ~5μs when operating at ~1 GB/s\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThere’s some queuing going on? 
I’m not sure of the reason\u003c/span\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eSwitching to the 2D version of the latency graph (Read Bandwidth -\u0026gt; Read Latency, 3D Graph -\u0026gt; 2D Graph):\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eTwo distinct latency regions emerge: a gentle increase from 5μs to 10μs (2 bytes to 64KB), then near-linear growth beyond 64KB\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis also holds when the NIC is at full throughput. It makes performance very predictable, which makes network bottlenecks easier to reason about\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eLatency variance remains stable across most message sizes (p50, p90, p99 are tightly grouped)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eSince NICs support bidirectional communication, we also need to measure performance when traffic flows in both directions simultaneously:\u003c/p\u003e\n\n\u003c!-- ib-benchmark.html --\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/ib_benchmark.css\" /\u003e\n\n\u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.27.1/plotly.min.js\"\u003e\u003c/script\u003e\n\n\u003cscript src=\"/assets/js/ib_benchmark.js\"\u003e\u003c/script\u003e\n\n\u003cdiv class=\"ib-benchmark-container\" id=\"ib-benchmark-container-nvme_ib_bidirectional\" data-path=\"/assets/images/posts/2025-03-13/ib/ib_benchmark_bidirectional.json\"\u003e\n \u003ch2\u003eIB benchmark bidirectional\u003c/h2\u003e\n \n \u003cdiv class=\"ib-controls\"\u003e\n \u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"testType-nvme_ib_bidirectional\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-nvme_ib_bidirectional\"\u003e\n \u003coption value=\"send_bw\"\u003eSend Bandwidth\u003c/option\u003e\n \u003coption value=\"send_lat\"\u003eSend Latency\u003c/option\u003e\n \u003coption value=\"read_bw\" selected=\"\"\u003eRead 
Bandwidth\u003c/option\u003e\n \u003coption value=\"read_lat\"\u003eRead Latency\u003c/option\u003e\n \u003coption value=\"write_bw\"\u003eWrite Bandwidth\u003c/option\u003e\n \u003coption value=\"write_lat\"\u003eWrite Latency\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"viewMode-nvme_ib_bidirectional\"\u003eView Mode\u003c/label\u003e\n \u003cselect id=\"viewMode-nvme_ib_bidirectional\"\u003e\n \u003coption value=\"3d\" selected=\"\"\u003e3D Graph\u003c/option\u003e\n \u003coption value=\"2d\"\u003e2D Graph\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"ib-plot-nvme_ib_bidirectional\" class=\"ib-plot-container ib-lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Panels indicator --\u003e\n \u003cdiv id=\"ib-panels-indicator-nvme_ib_bidirectional\" class=\"ib-panels-indicator\" style=\"display: none;\"\u003e\n \u003cspan\u003eActive Panels: \u003cspan id=\"ib-panel-count-nvme_ib_bidirectional\"\u003e0\u003c/span\u003e\u003c/span\u003e\n \u003cbutton id=\"ib-arrange-btn-nvme_ib_bidirectional\" class=\"ib-action-button\"\u003eArrange\u003c/button\u003e\n \u003cbutton id=\"ib-close-all-btn-nvme_ib_bidirectional\" class=\"ib-action-button\"\u003eClose All\u003c/button\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"ib-benchmark-note\"\u003e\n \u003cp\u003eIB Benchmark on bidirectional throughput/latency\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.ib-benchmark-container');\n const id = container.id.replace('ib-benchmark-container-', '');\n const plotEl = document.getElementById('ib-plot-' + id);\n \n // Check if already loading or 
loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadIBBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('ib-plot-nvme_ib_bidirectional');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadIBBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForIBBenchmarkJs() {\n if (typeof initInfiniBandPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('ib-benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! 
Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n const processedData = processBenchmarkData(data);\n const plotId = 'ib-plot-' + id;\n const testType = document.getElementById('testType-' + id).value;\n const viewMode = document.getElementById('viewMode-' + id).value;\n \n // Use our new function to initialize the plot\n window['ibPlot_' + id] = initInfiniBandPlot(plotId, processedData, {\n defaultTest: testType || 'send_bw',\n viewMode: viewMode || '3d'\n });\n \n document.getElementById(plotId).classList.remove('ib-lazy-load');\n \n // Setup event listeners for controls\n document.getElementById('testType-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setTestType) {\n plotObj.setTestType(e.target.value);\n }\n });\n \n document.getElementById('viewMode-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setViewMode) {\n plotObj.setViewMode(e.target.value);\n }\n });\n })\n .catch(error =\u003e {\n console.error('Error loading InfiniBand benchmark data:', error);\n document.getElementById('ib-plot-' + id).innerHTML = \n '\u003cdiv class=\"ib-error\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForIBBenchmarkJs(), 100);\n }\n }\n \n waitForIBBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe bidirectional results follow a similar pattern:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAt 4K-8K message sizes, we achieve double the throughput while latency drops from 30-60μs to 15-30μs\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis counterintuitive result likely occurs because each direction gets dedicated hardware resources, allowing better pipeline utilization\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eCombined bandwidth reaches ~23 GB/s (~92% of the theoretical 25 GB/s)\u003c/li\u003e\n \u003cli\u003eLatencies remain consistent with unidirectional measurements\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThese measurements give us concrete expectations for 3FS operations. For example, when 3FS performs a 3-node write (1KB from 3 storage nodes), the network alone should account for roughly 3-10μs. Any latency above this comes from other software or hardware overhead, such as chunk management, thread contention, or disk I/O.\u003c/p\u003e\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eComparison to NCCL all_reduce_perf (for fun)\u003c/summary\u003e\n\n \u003cp\u003eNCCL is the standard framework for GPU-to-GPU communication in machine learning clusters. 
Since GPUs also use InfiniBand for inter-node communication, I wanted to see if the same performance patterns emerge.\u003c/p\u003e\n\n \u003cp\u003eThis test uses a 2-node cluster with 8x400Gbps InfiniBand NICs (~400GB/s total), typical for modern GPU clusters like 8xH100 setups.\u003c/p\u003e\n\n \u003c!-- ib-benchmark.html --\u003e\n\n \u003clink rel=\"stylesheet\" href=\"/assets/css/ib_benchmark.css\" /\u003e\n\n \u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.27.1/plotly.min.js\"\u003e\u003c/script\u003e\n\n \u003cscript src=\"/assets/js/ib_benchmark.js\"\u003e\u003c/script\u003e\n\n \u003cdiv class=\"ib-benchmark-container\" id=\"ib-benchmark-container-_ib_bidirectional\" data-path=\"/assets/images/posts/2025-03-13/ib/nccl.json\"\u003e\n \u003ch2\u003eNCCL all_reduce_perf\u003c/h2\u003e\n \n \u003cdiv class=\"ib-controls\"\u003e\n \u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"testType-_ib_bidirectional\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-_ib_bidirectional\"\u003e\n \u003coption value=\"send_bw\"\u003eSend Bandwidth\u003c/option\u003e\n \u003coption value=\"send_lat\"\u003eSend Latency\u003c/option\u003e\n \u003coption value=\"read_bw\"\u003eRead Bandwidth\u003c/option\u003e\n \u003coption value=\"read_lat\"\u003eRead Latency\u003c/option\u003e\n \u003coption value=\"write_bw\"\u003eWrite Bandwidth\u003c/option\u003e\n \u003coption value=\"write_lat\"\u003eWrite Latency\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"ib-control-group\"\u003e\n \u003clabel for=\"viewMode-_ib_bidirectional\"\u003eView Mode\u003c/label\u003e\n \u003cselect id=\"viewMode-_ib_bidirectional\"\u003e\n \u003coption value=\"3d\" selected=\"\"\u003e3D Graph\u003c/option\u003e\n \u003coption value=\"2d\"\u003e2D Graph\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"ib-plot-_ib_bidirectional\" class=\"ib-plot-container 
ib-lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Panels indicator --\u003e\n \u003cdiv id=\"ib-panels-indicator-_ib_bidirectional\" class=\"ib-panels-indicator\" style=\"display: none;\"\u003e\n \u003cspan\u003eActive Panels: \u003cspan id=\"ib-panel-count-_ib_bidirectional\"\u003e0\u003c/span\u003e\u003c/span\u003e\n \u003cbutton id=\"ib-arrange-btn-_ib_bidirectional\" class=\"ib-action-button\"\u003eArrange\u003c/button\u003e\n \u003cbutton id=\"ib-close-all-btn-_ib_bidirectional\" class=\"ib-action-button\"\u003eClose All\u003c/button\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"ib-benchmark-note\"\u003e\n \u003cp\u003eall_reduce_perf\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.ib-benchmark-container');\n const id = container.id.replace('ib-benchmark-container-', '');\n const plotEl = document.getElementById('ib-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadIBBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('ib-plot-_ib_bidirectional');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadIBBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForIBBenchmarkJs() {\n if (typeof initInfiniBandPlot === 'function') {\n // Function 
exists, proceed with initialization\n const container = document.getElementById('ib-benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n const processedData = processBenchmarkData(data);\n const plotId = 'ib-plot-' + id;\n const testType = document.getElementById('testType-' + id).value;\n const viewMode = document.getElementById('viewMode-' + id).value;\n \n // Use our new function to initialize the plot\n window['ibPlot_' + id] = initInfiniBandPlot(plotId, processedData, {\n defaultTest: testType || 'send_bw',\n viewMode: viewMode || '3d'\n });\n \n document.getElementById(plotId).classList.remove('ib-lazy-load');\n \n // Setup event listeners for controls\n document.getElementById('testType-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setTestType) {\n plotObj.setTestType(e.target.value);\n }\n });\n \n document.getElementById('viewMode-' + id).addEventListener('change', function(e) {\n const plotObj = window['ibPlot_' + id];\n if (plotObj \u0026\u0026 plotObj.setViewMode) {\n plotObj.setViewMode(e.target.value);\n }\n });\n })\n .catch(error =\u003e {\n console.error('Error loading InfiniBand benchmark data:', error);\n document.getElementById('ib-plot-' + id).innerHTML = \n '\u003cdiv class=\"ib-error\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForIBBenchmarkJs(), 100);\n }\n }\n \n waitForIBBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n \u003cp\u003eThe bandwidth pattern is similar (slow climb, then rapid rise), but peak performance is reached at ~512MB messages instead of 8KB\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eLikely due to multiple NICs and the collective communication overhead of all_reduce operations\u003c/span\u003e. At the same 8KB message size where our InfiniBand tests peaked, NCCL only achieves ~0.24 GB/s at ~20μs.\u003c/p\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"storage-baseline-benchmark\"\u003eStorage Baseline Benchmark\u003c/h2\u003e\n\n\u003cp\u003e\u003ca href=\"https://fio.readthedocs.io/en/latest/fio_doc.html\"\u003eFIO\u003c/a\u003e is the standard tool for storage benchmarking on Linux, so I’ll use it in the next section. As a heads up, the 3FS authors conveniently provide a \u003ca href=\"https://github.com/deepseek-ai/3FS/tree/8c9883c27f50da8d1df8ff0b952483d21cdf1792/benchmarks/fio_usrbio\"\u003ecustom FIO engine\u003c/a\u003e specifically for benchmarking their filesystem\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis wasn’t in the original release; they added it after I started this analysis, which saved me quite a bit of time writing one myself\u003c/span\u003e, which we can compare against!\u003c/p\u003e\n\n\u003ch3 id=\"local-storage-performance\"\u003eLocal Storage Performance\u003c/h3\u003e\n\n\u003cp\u003eBefore measuring 3FS, we need baseline numbers for our SSDs. 
The following benchmarks show how bandwidth and latency change as we vary two key parameters:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eI/O depth\u003c/strong\u003e: How many operations are kept in flight before waiting for completion (think of it as the queue length)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eJob count\u003c/strong\u003e: How many parallel processes are hammering the storage simultaneously\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThese SSD numbers will become our reference point\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eFor example, with a replication factor of 3, we might see 3x higher read throughput or 3x higher write latency, but this might not be the case!\u003c/span\u003e: when 3FS shows higher latency or lower throughput, we can quantify exactly how much overhead the distributed layer adds.\u003c/p\u003e\n\n\u003cp\u003eI’ll benchmark the local SSD with io_uring, then 3FS with io_uring, and finally 3FS through its own custom io_uring-like interface (USRBIO).\u003c/p\u003e\n\n\u003cp\u003eI configured 3FS with a replication factor of 3.\u003c/p\u003e\n\n\u003ch4 id=\"hardware-vendor-specifications\"\u003eHardware Vendor Specifications\u003c/h4\u003e\n\n\u003cp\u003eBefore examining our benchmark results, let’s establish the theoretical performance limits according to hardware vendor specifications. These numbers represent the maximum performance we could theoretically achieve under ideal conditions:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Drive\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Random Read\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Sequential Read\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Random Write\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Sequential Write\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSATA SSD\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e276 MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e450 MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row0-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e72 MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row0-col4\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e380 MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNVMe\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.77 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e6.2 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row1-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0.4 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Performance Metric,Random Read,Sequential Read,Random Write,Sequential Write-row1-col4\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.3 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThese theoretical limits come from the \u003ca href=\"https://servak.com.ua/image/manual/SSD/SSD_240GB_2.5_6G_INTEL_DC_S3520_SERIES_SATA_Quick_Specs_Servak_2.pdf\"\u003eIntel DC S3520 SATA\u003c/a\u003e and \u003ca href=\"https://dl.dell.com/manuals/all-products/esuprt_data_center_infra_int/esuprt_data_center_infra_storage_adapters/dell-poweredge-exp-fsh-nvme-pcie-ssd_users-guide7_en-us.pdf\"\u003eDell Enterprise NVMe\u003c/a\u003e specification sheets. In practice, our benchmarks will likely fall short of these numbers due to filesystem overhead, driver limitations, and real-world I/O patterns.\u003c/p\u003e\n\n\u003cp\u003eThe gap between SATA and NVMe storage is immediately apparent.
NVMe provides roughly 14x higher read throughput (both random and sequential), and this difference may affect how 3FS performs.\u003c/p\u003e\n\n\u003ch1 id=\"benchmarking-for-older-cluster\"\u003eBenchmarking for Older Cluster\u003c/h1\u003e\n\n\u003ch2 id=\"local-fio-results\"\u003eLocal FIO results\u003c/h2\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand for instructions on how to interact with the graph\u003c/summary\u003e\n\n \u003cp\u003e\u003cstrong\u003eControls:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eTest Type menu\u003c/strong\u003e: Switch between Random Read, Sequential Read, Random Write, and Sequential Write\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eMetric menu\u003c/strong\u003e: Change between Bandwidth, IOPS, and various Latency measurements\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cp\u003e\u003cstrong\u003e3D Navigation:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eClick and drag\u003c/strong\u003e: Rotate the view\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eScroll wheel\u003c/strong\u003e: Zoom in/out\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eHover\u003c/strong\u003e: See exact values for any data point\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eDouble-click\u003c/strong\u003e: Reset to default view\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cp\u003e\u003cstrong\u003eAxes:\u003c/strong\u003e\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eX-axis\u003c/strong\u003e: IO Depth (1 to 128)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eY-axis\u003c/strong\u003e: Number of Jobs (1 to 128)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eColor\u003c/strong\u003e: The selected metric value (blue = low, red = high)\u003c/li\u003e\n \u003c/ul\u003e\n\n\u003c/details\u003e\n\n\u003ch3 id=\"scaling-block-size-for-local-ssd\"\u003eScaling block size
for local SSD\u003c/h3\u003e\n\n\u003cp\u003eThe first benchmark uses the older cluster to establish our local SSD baseline. I’m testing how performance changes across block sizes (4K, 64K, 1MB, 4MB) to understand the storage characteristics of a SATA SSD. The local SSD is formatted with the XFS filesystem.\u003c/p\u003e\n\n\u003cp\u003eThis is a lot of data. Feel free to jump between the interactive graphs and the \u003ca href=\"#storage-performance-analysis-for-local-ssd\"\u003eperformance analysis\u003c/a\u003e to explore the patterns.\u003c/p\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/benchmark.css\" /\u003e\n\n\u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.27.1/plotly.min.js\"\u003e\u003c/script\u003e\n\n\u003cscript src=\"/assets/js/benchmark.js\"\u003e\u003c/script\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_iouring_4k\" data-path=\"/assets/images/posts/2025-03-13/fio/4k_ssd_xfs_iouring_xl170_1.json\"\u003e\n \u003ch2\u003e4K Block Size - SSD XFS with IO_URING (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_iouring_4k\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_iouring_4k\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_iouring_4k\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_iouring_4k\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption
value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_iouring_4k\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_iouring_4k\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_iouring_4k\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_iouring_4k\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_iouring_4k\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_iouring_4k\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_iouring_4k\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_iouring_4k\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003eSmall block (4K) performance using SSD with XFS filesystem and IO_URING driver on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = 
entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_iouring_4k');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/benchmark.css\" /\u003e\n\n\u003cscript src=\"https://cdnjs.cloudflare.com/ajax/libs/plotly.js/2.27.1/plotly.min.js\"\u003e\u003c/script\u003e\n\n\u003cscript src=\"/assets/js/benchmark.js\"\u003e\u003c/script\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_iouring_64k\" data-path=\"/assets/images/posts/2025-03-13/fio/64k_ssd_xfs_iouring_xl170_1.json\"\u003e\n \u003ch2\u003e64k Block Size - SSD XFS with IO_URING (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_iouring_64k\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_iouring_64k\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_iouring_64k\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_iouring_64k\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n 
\u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_iouring_64k\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_iouring_64k\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_iouring_64k\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_iouring_64k\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_iouring_64k\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_iouring_64k\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_iouring_64k\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_iouring_64k\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003e64K block performance using SSD with XFS filesystem and IO_URING driver on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id =
container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_iouring_64k');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_iouring_1m\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_ssd_xfs_iouring_xl170_1.json\"\u003e\n \u003ch2\u003e1M Block Size - SSD XFS with IO_URING (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_iouring_1m\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_iouring_1m\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_iouring_1m\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_iouring_1m\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 
(μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_iouring_1m\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_iouring_1m\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_iouring_1m\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_iouring_1m\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_iouring_1m\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_iouring_1m\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_iouring_1m\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_iouring_1m\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance characteristics of SSD with XFS filesystem using IO_URING driver with 1M block size on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and 
initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_iouring_1m');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_iouring_4m\" data-path=\"/assets/images/posts/2025-03-13/fio/4m_ssd_xfs_iouring_xl170_1.json\"\u003e\n \u003ch2\u003e4m Block Size - SSD XFS with IO_URING (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_iouring_4m\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_iouring_4m\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_iouring_4m\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_iouring_4m\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 
(μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_iouring_4m\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_iouring_4m\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_iouring_4m\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_iouring_4m\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_iouring_4m\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_iouring_4m\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_iouring_4m\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_iouring_4m\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance characteristics of SSD with XFS filesystem using IO_URING driver with 4m block size on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and 
initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_iouring_4m');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003ch3 id=\"storage-performance-analysis-for-local-ssd\"\u003eStorage Performance Analysis for Local SSD\u003c/h3\u003e\n\n\u003cp\u003eLet’s examine how performance changes across different block sizes by looking at a specific configuration point: various IO depths at 1 job\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eWhy 1 job? This removes one variable from our analysis, allowing us to focus on how IO depth affects performance. We’ll explore job scaling separately\u003c/span\u003e.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/throughput_versus_latency_explain.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThis graph reveals the classic throughput versus latency tradeoff for our SATA SSD\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThese plots are fundamental to understanding storage performance - they show exactly when a system hits diminishing returns\u003c/span\u003e. The Y-axis shows throughput (higher is better), while the X-axis shows latency (lower is better). 
Each colored line represents a different block size, with dots marking increasing IO depths.\u003c/p\u003e\n\n\u003cp\u003eFirst, let’s examine each axis independently:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eY-axis (Throughput): 64K blocks achieve the highest peak at 400 MB/s, while the other sizes fall short: 4K reaches 250 MB/s, 1M hits 325 MB/s, and 4M peaks at 350 MB/s\u003c/li\u003e\n \u003cli\u003eX-axis (Latency): Large block sizes (1M and 4M) show dramatically higher latency (80ms+) than smaller block sizes (4K and 64K)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe cool thing about throughput versus latency graphs is that they expose a knee point: the point where throughput stops increasing but latency keeps climbing\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eSome systems even lose throughput past this point, since they may need to do additional work to manage the queued work items\u003c/span\u003e. For 64K blocks, this occurs around IO depth 16-32, where we achieve ~400 MB/s at \u0026lt; 10ms.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/knee_point.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eKnee point for throughput versus latency graph\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view throughput versus latency graphs for other workloads\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/read_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);
\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential reads\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randwrite_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/write_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view throughput versus latency graphs for random reads as numjobs scales\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randread_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads for 1 numjobs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n 
\u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randread_throughput_vs_latency_all_depths_2jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads for 2 numjobs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randread_throughput_vs_latency_all_depths_4jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads for 4 numjobs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randread_throughput_vs_latency_all_depths_8jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads for 8 numjobs\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cp\u003eThese measurements reveal something frustrating, but also quite interesting: there’s no universal sweet spot. 
What works best depends on whether you prioritize latency or throughput, and that in turn depends on what your workload looks like.\u003c/p\u003e\n\n\u003cp\u003eA few interesting things to observe:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eLatency grows by different amounts as block size increases\u003c/li\u003e\n \u003cli\u003eLatency doubles with each doubling of numjobs\u003c/li\u003e\n \u003cli\u003eNo single block size maximizes bandwidth across workloads: for random reads it’s 64K, while for sequential reads it’s 4K.\u003c/li\u003e\n \u003cli\u003eFor the lowest latency, use a smaller block size, though the SSD most likely won’t saturate its bandwidth.\u003c/li\u003e\n \u003cli\u003eWrites have different knee points than reads (for example, 4K sequential writes hit their knee at 150 MB/s, while 4K sequential reads cap at 300 MB/s)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eWith these patterns established, let’s examine the NVMe fio benchmarks to see whether these observations hold true or if new patterns emerge.\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eDouble-checking whether libaio makes a difference\u003c/summary\u003e\n\n \u003cp\u003eThe graphs above show io_uring results. Does another async I/O library, libaio, perform any differently?\u003c/p\u003e\n\n \u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n \u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n \u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n \u003cdiv id=\"sata-ssd-performance-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003ctable id=\"sata-ssd-performance-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Workload\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Configuration\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Bandwidth (MB/s)\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Avg Latency (ms)\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n P99 Latency (ms)\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sata-ssd-performance-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e242.6\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.01\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.29\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e329.8\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e377.50\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row1-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e484.44\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e240.7\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.02\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row2-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.32\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e329.9\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row3-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e378.01\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row3-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e488.64\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e153.6\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row4-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.18\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row4-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5.47\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e159.6\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row5-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e780.23\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row5-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e977.27\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n 
\n \n \n \u003ctd id=\"sata-ssd-performance-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e151.3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row6-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.23\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row6-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5.55\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRandom Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e153.3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row7-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e855.61\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row7-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e935.33\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row8-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row8-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row8-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e410.7\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row8-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.22\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row8-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003e1.97\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row9-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row9-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row9-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e276.7\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row9-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e460.66\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row9-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e488.64\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row10-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row10-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003e4K libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row10-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e402.1\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row10-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.22\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row10-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.01\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row11-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row11-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row11-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e270.3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row11-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e467.39\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sata-ssd-performance-table-row11-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e497.03\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row12-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row12-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row12-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e148.2\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row12-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.30\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row12-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5.47\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row13-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"sata-ssd-performance-table-row13-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M iouring\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row13-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e143.7\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row13-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e866.88\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row13-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e935.33\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row14-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row14-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4K libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row14-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e147.7\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row14-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.29\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row14-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e5.44\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row15-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row15-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1M libaio\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row15-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e145.6\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row15-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e855.61\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"sata-ssd-performance-table-row15-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e960.50\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n \u003c!-- benchmark.html --\u003e\n\n \u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_libaio_4k\" 
data-path=\"/assets/images/posts/2025-03-13/fio/4k_ssd_xfs_libaio_xl170_1.json\"\u003e\n \u003ch2\u003e4K Block Size - SSD XFS with LIBAIO (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_libaio_4k\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_libaio_4k\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_libaio_4k\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_libaio_4k\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_libaio_4k\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_libaio_4k\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_libaio_4k\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_libaio_4k\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_libaio_4k\" class=\"collapse-btn\" 
title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_libaio_4k\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_libaio_4k\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_libaio_4k\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance comparison of SSD with XFS filesystem using LIBAIO driver with 4k block size on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_libaio_4k');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function 
exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n \u003c!-- benchmark.html --\u003e\n\n \u003cdiv class=\"benchmark-container\" id=\"benchmark-container-ssd_xfs_libaio_1m\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_ssd_xfs_libaio_xl170_1.json\"\u003e\n \u003ch2\u003e1M Block Size - SSD XFS with LIBAIO (Older)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-ssd_xfs_libaio_1m\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-ssd_xfs_libaio_1m\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" 
selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-ssd_xfs_libaio_1m\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-ssd_xfs_libaio_1m\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-ssd_xfs_libaio_1m\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-ssd_xfs_libaio_1m\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-ssd_xfs_libaio_1m\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-ssd_xfs_libaio_1m\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-ssd_xfs_libaio_1m\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-ssd_xfs_libaio_1m\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-ssd_xfs_libaio_1m\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-ssd_xfs_libaio_1m\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n 
\u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance comparison of SSD with XFS filesystem using LIBAIO driver with 1M block size on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-ssd_xfs_libaio_1m');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! 
Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n \u003cp\u003eNo sizable difference.\u003c/p\u003e\n\n\u003c/details\u003e\n\n\u003ch1 id=\"benchmarking-for-modern-cluster\"\u003eBenchmarking for Modern Cluster\u003c/h1\u003e\n\n\u003ch2 id=\"local-fio-results-1\"\u003eLocal FIO results\u003c/h2\u003e\n\n\u003ch3 id=\"scaling-block-size-for-local-nvme\"\u003eScaling block size for local NVMe\u003c/h3\u003e\n\n\u003cp\u003eAgain, feel free to jump between the interactive graphs and the \u003ca href=\"#storage-performance-analysis-for-local-nvme\"\u003eperformance analysis\u003c/a\u003e to explore the patterns.\u003c/p\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-4k_nvme_xfs_iouring_r650\" data-path=\"/assets/images/posts/2025-03-13/fio/4k_nvme_xfs_iouring_r650_1.json\"\u003e\n \u003ch2\u003e4k Block Size - NVME XFS with IO_URING (Modern)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv 
class=\"control-group\"\u003e\n \u003clabel for=\"testType-4k_nvme_xfs_iouring_r650\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-4k_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-4k_nvme_xfs_iouring_r650\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-4k_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-4k_nvme_xfs_iouring_r650\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-4k_nvme_xfs_iouring_r650\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-4k_nvme_xfs_iouring_r650\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-4k_nvme_xfs_iouring_r650\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-4k_nvme_xfs_iouring_r650\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-4k_nvme_xfs_iouring_r650\" class=\"close-btn\" 
title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-4k_nvme_xfs_iouring_r650\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-4k_nvme_xfs_iouring_r650\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of NVME with XFS filesystem using IO_URING driver on modern cluster with 4k block size.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-4k_nvme_xfs_iouring_r650');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + 
id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-64k_nvme_xfs_iouring_r650\" data-path=\"/assets/images/posts/2025-03-13/fio/64k_nvme_xfs_iouring_r650_1.json\"\u003e\n \u003ch2\u003e64k Block Size - NVME XFS with IO_URING (Modern)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-64k_nvme_xfs_iouring_r650\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-64k_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption 
value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-64k_nvme_xfs_iouring_r650\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-64k_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-64k_nvme_xfs_iouring_r650\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-64k_nvme_xfs_iouring_r650\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-64k_nvme_xfs_iouring_r650\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-64k_nvme_xfs_iouring_r650\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-64k_nvme_xfs_iouring_r650\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-64k_nvme_xfs_iouring_r650\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-64k_nvme_xfs_iouring_r650\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-64k_nvme_xfs_iouring_r650\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n 
\n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of NVME with XFS filesystem using IO_URING driver on modern cluster with 64k block size.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-64k_nvme_xfs_iouring_r650');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! 
Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-1m_nvme_xfs_iouring_r650\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_nvme_xfs_iouring_r650_1.json\"\u003e\n \u003ch2\u003e1M Block Size - NVME XFS with IO_URING (Modern)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-1m_nvme_xfs_iouring_r650\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-1m_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel 
for=\"metricType-1m_nvme_xfs_iouring_r650\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-1m_nvme_xfs_iouring_r650\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-1m_nvme_xfs_iouring_r650\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-1m_nvme_xfs_iouring_r650\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-1m_nvme_xfs_iouring_r650\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-1m_nvme_xfs_iouring_r650\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-1m_nvme_xfs_iouring_r650\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-1m_nvme_xfs_iouring_r650\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-1m_nvme_xfs_iouring_r650\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-1m_nvme_xfs_iouring_r650\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of NVME with XFS filesystem using IO_URING driver on modern cluster with 1M block size.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n 
document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-1m_nvme_xfs_iouring_r650');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! 
Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\u003ch3 id=\"storage-performance-analysis-for-local-nvme\"\u003eStorage Performance Analysis for local NVMe\u003c/h3\u003e\n\n\u003cp\u003eLet’s examine how the NVMe drive performs compared to our SATA baseline:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randread_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 110%; margin-left: calc((100% - 110%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads on NVMe\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view SATA SSD comparison graph\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" 
style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_ssd/randread_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads on SATA SSD\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cp\u003eThe NVMe improvement is dramatic:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eThroughput:\u003c/strong\u003e 4x to 10x higher depending on block size (1 GB/s vs 250 MB/s for 4K, 4 GB/s vs 400 MB/s for 64K)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eLatency:\u003c/strong\u003e Consistently lower, especially for large blocks\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eFor 64K blocks, NVMe stays at ~1ms while SATA climbs to ~20ms, a 20x difference\u003c/span\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eOne notable difference from the SATA patterns:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e64K and 1M blocks need higher IO depths to hit their knee points, suggesting NVMe controllers require more parallelism for peak performance\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003e3FS may need to be configured with sufficient parallelism to extract maximum NVMe performance\u003c/span\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view throughput versus latency graphs for other workloads\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/read_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%;
margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential reads\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randwrite_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/write_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view throughput versus latency graphs for random reads as numjobs scales\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randread_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads with numjobs=1\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%;
overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randread_throughput_vs_latency_all_depths_2jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads with numjobs=2\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randread_throughput_vs_latency_all_depths_4jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads with numjobs=4\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_nvme/randread_throughput_vs_latency_all_depths_8jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads with numjobs=8\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003cp\u003eSequential reads follow the same pattern as random reads, with a similarly high throughput ceiling and low latency.\u003c/p\u003e\n\n\u003cp\u003eWrite performance reveals a different story.
Both random and sequential writes drop to ~2 GB/s peak throughput, with knee points occurring at much lower IO depths for 64K and 1M blocks\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis aligns with the vendor specification showing NVMe write performance (2.3 GB/s) is significantly lower than read performance (6.2 GB/s)\u003c/span\u003e.\u003c/p\u003e\n\n\u003cp\u003eThe numjobs scaling patterns mirror what we observed with SATA SSDs: throughput increases with additional parallel jobs, but latency scales proportionally. Doubling jobs roughly doubles latency but provides less than 2x throughput improvement.\u003c/p\u003e\n\n\u003ch2 id=\"predicting-3fs-performance\"\u003ePredicting 3FS Performance\u003c/h2\u003e\n\n\u003cp\u003eBefore diving into actual 3FS benchmarks, let’s make some predictions based on our hardware baseline measurements:\u003c/p\u003e\n\n\u003cp\u003eFor random and sequential reads, our theoretical ceiling is 18 GB/s: there’s a replication factor of 3, and the drives reach ~6 GB/s for both random and sequential reads.\u003c/p\u003e\n\n\u003cp\u003eHowever, we’re bound by network bandwidth, which has a theoretical limit of 12.5 GB/s (realistically ~11.5 GB/s from our previous micro-benchmarks).\u003c/p\u003e\n\n\u003cp\u003eNext, let’s estimate latency in the average and worst cases.
We can pull the network and disk latencies from our earlier graphs, starting with reads.\u003c/p\u003e\n\n\u003cp\u003eIn the average case:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eThe average network latency for 1MB of data is 91us\u003c/li\u003e\n \u003cli\u003eThe average disk latency for sequential/random reads for 1M block size (1 IO depth, 1 job) is 0.48ms\u003c/li\u003e\n \u003cli\u003eSo the latency we should expect is roughly 0.48ms\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eIn the worst case:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eThe p99 network latency for 1MB of data is 282us\u003c/li\u003e\n \u003cli\u003eThe p99 disk latency for sequential/random reads for 1M block size (128 IO depth, 16 jobs) is 448ms/420ms\u003c/li\u003e\n \u003cli\u003eSo the latency we should expect is roughly 448ms\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThat is nearly a 1000x difference in latency between the average and worst cases (0.48ms vs 448ms). We can also clearly see that latency is dominated by the disk.\u003c/p\u003e\n\n\u003cp\u003eMoving on to writes:\u003c/p\u003e\n\n\u003cp\u003eAverage case:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eNetwork: 91us\u003c/li\u003e\n \u003cli\u003eDisk: 0.46ms (1 IO depth, 1 job)\u003c/li\u003e\n \u003cli\u003eSo the combined latency is 0.46ms * 3 (chained) = 1.38ms\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eP99 case:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eNetwork: 187us\u003c/li\u003e\n \u003cli\u003eDisk: 892ms (128 IO depth, 16 jobs)\u003c/li\u003e\n \u003cli\u003eSo the combined latency is 892ms * 3 (chained) ≈ 2.68s\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eIn the worst case, writes can be nearly 2000x slower than in the average case (2.68s vs 1.38ms).
On top of that, writes pay a multiplicative penalty: each write has to go through every node in the chain, so the per-node disk latency is paid three times.\u003c/p\u003e\n\n\u003cp\u003eKeeping this in mind, let’s head into the benchmarks:\u003c/p\u003e\n\n\u003ch2 id=\"3fs\"\u003e3FS\u003c/h2\u003e\n\n\u003cp\u003e3FS is benchmarked using two different I/O interfaces: io_uring, the standard Linux asynchronous I/O interface, and USRBIO, a custom FIO engine that integrates directly with 3FS’s I/O queue management system.\u003c/p\u003e\n\n\u003ch3 id=\"io_uring\"\u003eIO_URING\u003c/h3\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-1m_hf3fs_xfs_iouring_r650\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_hf3fs_xfs_iouring_r650_5.json\"\u003e\n \u003ch2\u003e1M Block Size - HF3FS XFS with IO_URING (Modern)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-1m_hf3fs_xfs_iouring_r650\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-1m_hf3fs_xfs_iouring_r650\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-1m_hf3fs_xfs_iouring_r650\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-1m_hf3fs_xfs_iouring_r650\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption
value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-1m_hf3fs_xfs_iouring_r650\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-1m_hf3fs_xfs_iouring_r650\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-1m_hf3fs_xfs_iouring_r650\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-1m_hf3fs_xfs_iouring_r650\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-1m_hf3fs_xfs_iouring_r650\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-1m_hf3fs_xfs_iouring_r650\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-1m_hf3fs_xfs_iouring_r650\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-1m_hf3fs_xfs_iouring_r650\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of HF3FS with XFS filesystem using IO_URING driver on modern cluster with 1M block size.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if 
(!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-1m_hf3fs_xfs_iouring_r650');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003ch3 id=\"usrbio\"\u003eUSRBIO\u003c/h3\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-1m_hf3fs_xfs_usrbio_r650\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_hf3fs_xfs_usrbio_r650_5.json\"\u003e\n \u003ch2\u003e1M Block Size - HF3FS XFS with USRBIO (Modern)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-1m_hf3fs_xfs_usrbio_r650\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-1m_hf3fs_xfs_usrbio_r650\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-1m_hf3fs_xfs_usrbio_r650\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-1m_hf3fs_xfs_usrbio_r650\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency 
p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-1m_hf3fs_xfs_usrbio_r650\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-1m_hf3fs_xfs_usrbio_r650\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-1m_hf3fs_xfs_usrbio_r650\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-1m_hf3fs_xfs_usrbio_r650\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-1m_hf3fs_xfs_usrbio_r650\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-1m_hf3fs_xfs_usrbio_r650\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-1m_hf3fs_xfs_usrbio_r650\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-1m_hf3fs_xfs_usrbio_r650\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of HF3FS with XFS filesystem using USRBIO driver on modern cluster with 1M block size.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent 
duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-1m_hf3fs_xfs_usrbio_r650');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eOne thing to observe is that for io_uring, \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eio_depth\u003c/code\u003e does not affect performance.\u003c/p\u003e\n\n\u003cp\u003eAgain, here’s the 2D graph. Note that the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eIO_URING\u003c/code\u003e points all land on the same spot, since \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eio_depth\u003c/code\u003e does not change its performance.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_hf3fs/randread_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 110%; margin-left: calc((100% - 110%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random reads on 3FS\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eOne interesting thing to observe is that \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eio_uring\u003c/code\u003e has lower latency than \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eusrbio\u003c/code\u003e at the same throughput.\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand to view throughput versus latency graphs for other workloads\u003c/summary\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n 
\u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_hf3fs/read_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential reads\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_hf3fs/randwrite_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for random writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part3/fio_hf3fs/write_throughput_vs_latency_all_depths_1jobs.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eThroughput versus latency graph for sequential writes\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch2 id=\"does-the-performance-match-the-estimates\"\u003eDoes the performance match the estimates?\u003c/h2\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Metric,Predicted,Actual-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Metric,Predicted,Actual\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Metric\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Predicted\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Actual\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRead Latency (1MB)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0.48ms\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"fancy-table-Metric,Predicted,Actual-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.09ms (127% worse)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRead P99 Latency\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e304ms\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e194ms (36% better)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eRead Bandwidth\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e11.5 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003e10.3 GB/s (10% worse)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eWrite Latency (1MB)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.38ms\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.55ms (85% worse)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eWrite P99 Latency\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e0.89s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row4-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.1s (24% worse)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n 
\u003ctd id=\"fancy-table-Metric,Predicted,Actual-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eWrite Bandwidth\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2.1 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Metric,Predicted,Actual-row5-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.8 GB/s (14% worse)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe 2x latency overhead for reads and writes may be coming from the software side of things\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eWe’ll have to dig deeper later to see why\u003c/span\u003e. One interesting thing to see is that P99.9 latency is better for reads because the network bandwidth caps throughput before storage hits worst-case scenarios. 
What’s nice to see is that the bandwidth only decreases by 10-15%!\u003c/p\u003e\n\n\u003ch2 id=\"3fs-1\"\u003e3FS\u003c/h2\u003e\n\n\u003cp\u003eNow we examine how 3FS scales with block size and node count on the older cluster (SATA SSDs + 25 Gbps networking).\u003c/p\u003e\n\n\u003ch3 id=\"scaling-block-size-5-nodes\"\u003eScaling block size (5 nodes)\u003c/h3\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-hf3fs_xfs_usrbio_4k\" data-path=\"/assets/images/posts/2025-03-13/fio/4k_hf3fs_xfs_usrbio_xl170_5.json\"\u003e\n \u003ch2\u003e4K Block Size - HF3FS XFS with USRBIO (Older-5-Nodes)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-hf3fs_xfs_usrbio_4k\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-hf3fs_xfs_usrbio_4k\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-hf3fs_xfs_usrbio_4k\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-hf3fs_xfs_usrbio_4k\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-hf3fs_xfs_usrbio_4k\" 
class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-hf3fs_xfs_usrbio_4k\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-hf3fs_xfs_usrbio_4k\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-hf3fs_xfs_usrbio_4k\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-hf3fs_xfs_usrbio_4k\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-hf3fs_xfs_usrbio_4k\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-hf3fs_xfs_usrbio_4k\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-hf3fs_xfs_usrbio_4k\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of HF3FS with XFS filesystem using USRBIO driver with 4K block size on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n 
}, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-hf3fs_xfs_usrbio_4k');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-hf3fs_xfs_usrbio_1m_xl170\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_hf3fs_ext4_usrbio_xl170_5.json\"\u003e\n \u003ch2\u003e1M Block Size - HF3FS XFS with USRBIO (Older-5-Nodes)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-hf3fs_xfs_usrbio_1m_xl170\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-hf3fs_xfs_usrbio_1m_xl170\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-hf3fs_xfs_usrbio_1m_xl170\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-hf3fs_xfs_usrbio_1m_xl170\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n 
\u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-hf3fs_xfs_usrbio_1m_xl170\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-hf3fs_xfs_usrbio_1m_xl170\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-hf3fs_xfs_usrbio_1m_xl170\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-hf3fs_xfs_usrbio_1m_xl170\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-hf3fs_xfs_usrbio_1m_xl170\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-hf3fs_xfs_usrbio_1m_xl170\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-hf3fs_xfs_usrbio_1m_xl170\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-hf3fs_xfs_usrbio_1m_xl170\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003eMedium block (1M) performance using HF3FS with XFS filesystem and USRBIO driver on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if 
(!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-hf3fs_xfs_usrbio_1m_xl170');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe 4K block size stays well below the 3.125 GB/s (25 Gbps) network limit, reaching only 1 GB/s with 4ms latency. The 1M block size hits the network bandwidth ceiling but pays a latency penalty (6ms at 1 IO depth with 8 jobs, compared to 4K’s 4ms maximum).\u003c/p\u003e\n\n\u003ch3 id=\"scaling-nodes\"\u003eScaling nodes\u003c/h3\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-hf3fs_xfs_usrbio_1m_xl170_5\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_hf3fs_ext4_usrbio_xl170_5.json\"\u003e\n \u003ch2\u003e1M Block Size - HF3FS XFS with USRBIO (Older-5-Nodes)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-hf3fs_xfs_usrbio_1m_xl170_5\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-hf3fs_xfs_usrbio_1m_xl170_5\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-hf3fs_xfs_usrbio_1m_xl170_5\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-hf3fs_xfs_usrbio_1m_xl170_5\"\u003e\n 
\u003coption value=\"bandwidth\" 
selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-hf3fs_xfs_usrbio_1m_xl170_5\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-hf3fs_xfs_usrbio_1m_xl170_5\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-hf3fs_xfs_usrbio_1m_xl170_5\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003eMedium block (1M) performance using HF3FS with XFS filesystem and USRBIO driver on older cluster.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, 
observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-hf3fs_xfs_usrbio_1m_xl170_5');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-hf3fs_xfs_iouring_1m_xl170_18\" data-path=\"/assets/images/posts/2025-03-13/fio/1m_hf3fs_xfs_usrbio_xl170_18.json\"\u003e\n \u003ch2\u003e1M Block Size - HF3FS XFS with IO_URING (Older-18-Nodes)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-hf3fs_xfs_iouring_1m_xl170_18\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-hf3fs_xfs_iouring_1m_xl170_18\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-hf3fs_xfs_iouring_1m_xl170_18\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-hf3fs_xfs_iouring_1m_xl170_18\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 
(μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-hf3fs_xfs_iouring_1m_xl170_18\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-hf3fs_xfs_iouring_1m_xl170_18\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-hf3fs_xfs_iouring_1m_xl170_18\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-hf3fs_xfs_iouring_1m_xl170_18\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-hf3fs_xfs_iouring_1m_xl170_18\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-hf3fs_xfs_iouring_1m_xl170_18\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-hf3fs_xfs_iouring_1m_xl170_18\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-hf3fs_xfs_iouring_1m_xl170_18\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003ePerformance of HF3FS with XFS filesystem using IO_URING driver with 1M blocks on 18 node configuration.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if 
already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', // Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-hf3fs_xfs_iouring_1m_xl170_18');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eComparing 5 vs 18 nodes with 1M blocks shows that latency increases with cluster size. At 18 nodes, scaling jobs works better than scaling IO depth for latency: 8 jobs/1 IO depth achieves 10ms @ 1.25 GB/s while 1 job/128 IO depth hits 90ms @ 1 GB/s.\u003c/p\u003e\n\n\u003cp\u003eWith 18 nodes at 300 MB/s each, we’d expect 5.4 GB/s total, but the 25 Gbps network caps us at 3.25 GB/s and realistically we get 2.35 GB/s.\u003c/p\u003e\n\n\u003cp\u003eOne glaring issue is that past a certain point, throughput drops significantly, while the local results hold their bandwidth. 
I’m not entirely sure why that is, but it suggests that configuration matters even more here, since throughput can drop drastically.\u003c/p\u003e\n\n\u003ch3 id=\"watch-out-for-really-large-block-sizes\"\u003eWatch out for really large block sizes\u003c/h3\u003e\n\n\u003c!-- benchmark.html --\u003e\n\n\u003cdiv class=\"benchmark-container\" id=\"benchmark-container-hf3fs_xfs_usrbio_4m_xl170\" data-path=\"/assets/images/posts/2025-03-13/fio/4m_hf3fs_xfs_usrbio_xl170_18.json\"\u003e\n \u003ch2\u003e4M Block Size - HF3FS XFS with USRBIO (Older-18)\u003c/h2\u003e\n \n \u003cdiv class=\"controls\"\u003e\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"testType-hf3fs_xfs_usrbio_4m_xl170\"\u003eTest Type\u003c/label\u003e\n \u003cselect id=\"testType-hf3fs_xfs_usrbio_4m_xl170\"\u003e\n \u003coption value=\"randread\"\u003eRandom Read\u003c/option\u003e\n \u003coption value=\"read\" selected=\"\"\u003eSequential Read\u003c/option\u003e\n \u003coption value=\"randwrite\"\u003eRandom Write\u003c/option\u003e\n \u003coption value=\"write\"\u003eSequential Write\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n\n \u003cdiv class=\"control-group\"\u003e\n \u003clabel for=\"metricType-hf3fs_xfs_usrbio_4m_xl170\"\u003eMetric\u003c/label\u003e\n \u003cselect id=\"metricType-hf3fs_xfs_usrbio_4m_xl170\"\u003e\n \u003coption value=\"bandwidth\" selected=\"\"\u003eBandwidth (GB/s)\u003c/option\u003e\n \u003coption value=\"iops\"\u003eIOPS\u003c/option\u003e\n \u003coption value=\"latency\"\u003eLatency (μs)\u003c/option\u003e\n \u003coption value=\"latency_p50\"\u003eLatency p50 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p90\"\u003eLatency p90 (μs)\u003c/option\u003e\n \u003coption value=\"latency_p99\"\u003eLatency p99 (μs)\u003c/option\u003e\n \u003c/select\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \u003cdiv id=\"benchmark-plot-hf3fs_xfs_usrbio_4m_xl170\" class=\"plot-container lazy-load\"\u003e\u003c/div\u003e\n \n \u003c!-- Draggable 
panel for latency data --\u003e\n \u003cdiv id=\"latencyPanel-hf3fs_xfs_usrbio_4m_xl170\" class=\"benchmark-draggable-panel\"\u003e\n \u003cdiv id=\"panelHeader-hf3fs_xfs_usrbio_4m_xl170\" class=\"panel-header\"\u003e\n \u003ch3 class=\"panel-title\" id=\"panelTitle-hf3fs_xfs_usrbio_4m_xl170\"\u003eLatency Percentiles\u003c/h3\u003e\n \u003cdiv class=\"panel-controls\"\u003e\n \u003cbutton id=\"collapseBtn-hf3fs_xfs_usrbio_4m_xl170\" class=\"collapse-btn\" title=\"Collapse\"\u003e▲\u003c/button\u003e\n \u003cbutton id=\"closeLatencyBtn-hf3fs_xfs_usrbio_4m_xl170\" class=\"close-btn\" title=\"Close\"\u003e×\u003c/button\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \u003cdiv class=\"panel-content\"\u003e\n \u003cdiv class=\"latency-details\" id=\"latencyDetails-hf3fs_xfs_usrbio_4m_xl170\"\u003e\u003c/div\u003e\n \u003cdiv id=\"latencyPlot-hf3fs_xfs_usrbio_4m_xl170\" class=\"latency-plot-container\"\u003e\u003c/div\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n\n \n \u003cdiv class=\"benchmark-note\"\u003e\n \u003cp\u003eLarge block (4M) performance using HF3FS with XFS filesystem and USRBIO driver on 18 node configuration.\u003c/p\u003e\n \u003c/div\u003e\n \n\n \u003cscript\u003e\n document.addEventListener('DOMContentLoaded', function() {\n // Create intersection observer for lazy loading\n const observer = new IntersectionObserver((entries, observer) =\u003e {\n entries.forEach(entry =\u003e {\n if (entry.isIntersecting) {\n const container = entry.target.closest('.benchmark-container');\n const id = container.id.replace('benchmark-container-', '');\n const plotEl = document.getElementById('benchmark-plot-' + id);\n \n // Check if already loading or loaded to prevent duplicate requests\n if (!plotEl.dataset.loading) {\n plotEl.dataset.loading = 'true';\n \n // Load the benchmark data and initialize the plot\n loadBenchmarkData(id);\n }\n \n // Stop observing once we've started loading\n observer.unobserve(entry.target);\n }\n });\n }, {\n rootMargin: '200px 0px', 
// Load when within 200px of viewport\n threshold: 0.01\n });\n \n // Start observing the plot container\n const plotContainer = document.getElementById('benchmark-plot-hf3fs_xfs_usrbio_4m_xl170');\n if (plotContainer) {\n observer.observe(plotContainer);\n }\n });\n \n // Function to load benchmark data\n function loadBenchmarkData(id) {\n // Ensure benchmark.js is loaded before initializing\n function waitForBenchmarkJs() {\n if (typeof initBenchmarkPlot === 'function') {\n // Function exists, proceed with initialization\n const container = document.getElementById('benchmark-container-' + id);\n const dataPath = container.getAttribute('data-path');\n \n if (dataPath) {\n // Use fetch API to load JSON data only when needed\n fetch(dataPath)\n .then(response =\u003e {\n if (!response.ok) {\n throw new Error(`HTTP error! Status: ${response.status}`);\n }\n return response.json();\n })\n .then(data =\u003e {\n // Store data and initialize the plot\n window['benchmarkData_' + id] = data;\n initBenchmarkPlot(id);\n document.getElementById('benchmark-plot-' + id).classList.remove('lazy-load');\n })\n .catch(error =\u003e {\n console.error('Error loading benchmark data:', error);\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError loading benchmark data. 
Check console for details.\u003c/div\u003e';\n });\n \n } else {\n console.error('No data source provided for benchmark visualization');\n document.getElementById('benchmark-plot-' + id).innerHTML = \n '\u003cdiv style=\"padding: 20px; color: red;\"\u003eError: No data source provided for benchmark visualization.\u003c/div\u003e';\n }\n } else {\n // Function not available yet, wait and try again\n setTimeout(() =\u003e waitForBenchmarkJs(), 100);\n }\n }\n \n waitForBenchmarkJs();\n }\n \u003c/script\u003e\n\u003c/div\u003e\n\n\u003cp\u003eFor 4M blocks, 3FS achieves 2.5 GB/s with just 1 IO depth and 8 jobs\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThis is about 77% of the theoretical 3.25 GB/s network limit.\u003c/span\u003e. Increasing either the node count or the block size shifts the curves only slightly.\u003c/p\u003e\n\n\u003ch2 id=\"wrapping-up\"\u003eWrapping up\u003c/h2\u003e\n\n\u003cp\u003eThe microbenchmarks reveal concrete performance characteristics for 3FS across different hardware configurations. We now have baseline numbers showing how 3FS compares to local storage and where the bottlenecks emerge.\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003e3FS adds predictable overhead: ~1ms for reads, ~1.2ms for writes\u003c/li\u003e\n \u003cli\u003eNetwork bandwidth becomes the limiting factor before storage saturation\u003c/li\u003e\n \u003cli\u003ePerformance scales reasonably with both block size and node count\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe next step is testing 3FS with actual workloads to see how much of this performance carries over in practice. 
Since 3FS has a relatively generic interface, we can compare with many other systems.\u003c/p\u003e\n\n\u003ch1 id=\"citation\"\u003eCitation\u003c/h1\u003e\n\n\u003cp\u003eTo cite this article:\u003c/p\u003e\n\n\u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e@article{zhu20253fs3,\n title = {Network Storage and Scaling Characteristics of a Distributed Filesystem},\n author = {Zhu, Henry},\n journal = {maknee.github.io},\n year = {2025},\n month = {September},\n url = \"https://maknee.github.io/blog/2025/3FS-Performance-Journal-3/\"\n}\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e","summary":"Series","date_published":"2025-09-16T06:00:00+00:00","date_modified":"2025-09-16T06:00:00+00:00","author":{"name":""},"tags":["3FS"]},{"id":"https://maknee.github.io/blog/2025/Network-And-Storage-Training-Skypilot","url":"https://maknee.github.io/blog/2025/Network-And-Storage-Training-Skypilot/","title":"Network and Storage Benchmarks for LLM Training on the Cloud","content_html":"\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/banner.png\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eAI usage has become universal. Teams everywhere are building RAG, generating embeddings, and training increasingly sophisticated agents.\u003c/p\u003e\n\n\u003cp\u003eMost distributed LLM training guides focus on model architecture and hyperparameters while ignoring a critical bottleneck: infrastructure configuration. 
Network and storage choices often determine whether training takes hours or days.\u003c/p\u003e\n\n\u003cp\u003eI ran benchmarks fine-tuning \u003ca href=\"https://huggingface.co/google/gemma-3-12b-it\"\u003eGemma 3 12B\u003c/a\u003e and \u003ca href=\"https://huggingface.co/openai/gpt-oss-120b\"\u003eGPT-OSS-120B\u003c/a\u003e with different storage and network configurations using \u003ca href=\"https://github.com/skypilot-org/skypilot\"\u003eSkyPilot\u003c/a\u003e for infra and \u003ca href=\"https://nebius.com/\"\u003eNebius\u003c/a\u003e for GPUs. The results reveal that InfiniBand networking provides 10x faster training than standard Ethernet, while optimal storage selection can speed up checkpointing by almost 2x. Combined, these infrastructure optimizations alone deliver a 6-7x end-to-end speedup.\u003c/p\u003e\n\n\u003ch2 id=\"some-background-on-training-bottlenecks\"\u003eSome background on training bottlenecks\u003c/h2\u003e\n\n\u003cp\u003eHere’s something that surprises most people new to large-scale training: your GPUs are most likely not the limiting factor. Modern accelerators like H200s will happily consume whatever data you can feed them. The real challenge is keeping them fed.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/compute.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eGPU compute scaling vs memory/network bandwidth\n \n (Image source: \u003ca href=\"https://horace.io/brrr_intro.html\" rel=\"external nofollow noopener\" target=\"_blank\"\u003ehorace\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThink of your GPU as an extremely efficient factory. It can process raw materials (your data) at incredible speeds, but it depends entirely on a steady supply chain. 
Your storage systems hold the raw materials, and the bandwidth between storage and compute acts as the conveyor belt. These days, that conveyor belt has become the constraint.\u003c/p\u003e\n\n\u003cp\u003eWhile GPU compute capability has grown exponentially, memory bandwidth and network speeds have followed a more modest trajectory.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/high_flyer_scaling.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eScaling trends in compute vs bandwidth\n \n (Image source: \u003ca href=\"https://arxiv.org/html/2408.14158v1\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eFire-Flyer AI-HPC: A Cost-Effective Software-Hardware Co-Design for Deep Learning\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003ch2 id=\"the-two-levers-you-control\"\u003eThe two levers you control\u003c/h2\u003e\n\n\u003cp\u003eWhen running distributed training, you have meaningful control over two critical components: storage and networking, especially when running on cloud GPUs.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/components.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThe objective is straightforward: maximize GPU utilization (or in other words, minimize GPU idleness). 
But achieving this requires understanding how data flows through your training pipeline and where bottlenecks typically emerge.\u003c/p\u003e\n\n\u003ch3 id=\"the-training-data-flow\"\u003eThe training data flow\u003c/h3\u003e\n\n\u003cp\u003eDuring training, data moves through these stages:\u003c/p\u003e\n\u003col\u003e\n \u003cli\u003e\u003cstrong\u003eLoad batches\u003c/strong\u003e from the dataset – storage\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eCommunicate gradients\u003c/strong\u003e between nodes – network\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eDump checkpoint\u003c/strong\u003e to save progress – storage\u003c/li\u003e\n\u003c/ol\u003e\n\n\u003cp\u003eBottlenecks can emerge at any of these steps. For example, loading datasets from storage or saving checkpoints to it might take extraordinarily long and block GPU progress. Alternatively, the inter-node network bandwidth might be insufficient for the communication operations that synchronize weights and gradients.\u003c/p\u003e\n\n\u003ch2 id=\"performance-benchmarks\"\u003ePerformance benchmarks\u003c/h2\u003e\n\n\u003cp\u003eI’ll use two concrete examples throughout:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eGoogle \u003ca href=\"https://huggingface.co/google/gemma-3-12b-it\"\u003eGemma 3 12B\u003c/a\u003e on 2 nodes × H100:8 GPUs\u003c/li\u003e\n \u003cli\u003eOpenAI \u003ca href=\"https://huggingface.co/openai/gpt-oss-120b\"\u003eGPT-OSS-120B\u003c/a\u003e on 4 nodes × H200:8 GPUs\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eTo quantify these effects, I ran experiments on Nebius, a Gold-tier GPU provider in \u003ca href=\"https://semianalysis.com/2025/03/26/the-gpu-cloud-clustermax-rating-system-how-to-rent-gpus/\"\u003eSemiAnalysis’s ClusterMAX GPU cloud rating\u003c/a\u003e.\u003c/p\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003eClick to see experimental setup\u003c/summary\u003e\n\nGemma 3 12B IT Configuration\n\n\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" 
/\u003e\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... (toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) 
\u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if 
(descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if 
(calcResultSpan) calcResultSpan.style.removeProperty('color');\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n 
display: inline-block;\n margin: 0 4px;\n font-weight: normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\u003cdiv id=\"fancy-table-Component,Specification-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Component,Specification\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Specification\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCloud Provider\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n 
\n \n \n \u003ctd id=\"fancy-table-Component,Specification-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eModel\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGemma 3 12B IT (Hugging Face)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGPUs per Node\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e8x H100s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTotal GPUs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e16x H100s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCPU Memory\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.5 TB\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row6-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eFramework\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row6-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHugging Face Accelerate with FSDP\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\nGPT-OSS-120B Configuration\n\n\u003cdiv id=\"fancy-table-Component,Specification-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable 
id=\"fancy-table-Component,Specification\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Specification\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCloud Provider\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"fancy-table-Component,Specification-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eModel\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGPT-OSS-120B\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGPUs per Node\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e8x H200s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTotal GPUs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e32x H200s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eFramework\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Specification-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHugging Face Accelerate with FSDP\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n**Network configurations tested**\n\n\u003cdiv id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Configuration\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Specification\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Theoretical Bandwidth\n 
\u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDefault Ethernet\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10 Gbit/s NIC\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~1.25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eInfiniBand\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e400 Gbit/s NIC × 8 cards\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Configuration,Specification,Theoretical Bandwidth-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~400 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n**Storage configurations tested**\n\nAll storage types are documented in [Nebius storage documentation](https://docs.nebius.com/compute/storage/types):\n\n\u003cdiv id=\"fancy-table-Storage Type,Description,Performance Profile-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Storage Type,Description,Performance Profile\" class=\"min-w-full divide-y divide-gray-200 font-sans 
basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Storage Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Description\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Performance Profile\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork SSD\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 
85, 99) !important;\"\u003enetwork_ssd_non_replicated\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStandard cloud block storage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius Shared Filesystem\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius's distributed file system offering\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHigh-performance distributed storage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDirect S3-compatible mounting\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCost-effective but high-latency\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT_CACHED)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSkyPilot's cached mounting\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Description,Performance Profile-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eWrites to local disk, then streams to object store\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003c/details\u003e\n\n\u003ch3 id=\"network-benchmarks-the-9x-performance-difference\"\u003eNetwork benchmarks: The 9x performance difference\u003c/h3\u003e\n\n\u003cp\u003eI compared two network configurations:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eStandard 10 Gbit/s Ethernet (the default on most clouds)\u003c/li\u003e\n \u003cli\u003eInfiniBand with 8 NICs at 400 Gbit/s each (high-performance 
networking)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe raw bandwidth difference is substantial: about 1.25 GB/s versus approximately 400 GB/s in aggregate, a gap of roughly 320x. But how does this translate to actual training throughput?\u003c/p\u003e\n\n\u003cp\u003eI ran the experiments on the Open-R1 dataset with this \u003ca href=\"https://github.com/skypilot-org/skypilot/blob/master/examples/training_network_storage_benchmarks/e2e_network.yaml\"\u003eSkyPilot YAML\u003c/a\u003e.\u003c/p\u003e\n\n\u003cdiv id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total
Training Time-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Network Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Raw Bandwidth\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Average Time per Step\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Total Training Time\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10 Gbit Ethernet\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~1.25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e39.8 seconds\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n 
\u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row0-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e53 minutes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNVIDIA Quantum-2 InfiniBand\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~400 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4.4 seconds\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Network Type,Raw Bandwidth,Average Time per Step,Total Training Time-row1-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e7 minutes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gemma_network.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThat’s a 9x speedup from network 
configuration alone. When you’re paying premium rates for GPU time, this is not just a performance improvement but also a cost optimization.\u003c/p\u003e\n\n\u003cp\u003eWith the \u003ca href=\"https://huggingface.co/openai/gpt-oss-120b\"\u003eGPT-OSS-120B\u003c/a\u003e model (10x larger), we see the same effect: a 10x speedup.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gpt_network.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eNormally, configuring high-performance networking takes significant effort, e.g., manually tuning cloud-specific configurations and setting various environment variables.\u003c/p\u003e\n\n\u003cp\u003eHere, \u003ca href=\"https://github.com/skypilot-org/skypilot\"\u003eSkyPilot\u003c/a\u003e handles this complexity under the hood with a single flag in the SkyPilot YAML:\u003c/p\u003e\n\n\u003cdiv class=\"language-yaml highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"na\"\u003ename\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003edistributed-training\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003eresources\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003eaccelerators\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eH100:8\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Enable high-performance networking for distributed training\u003c/span\u003e\n \u003cspan class=\"na\"\u003enetwork_tier\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003ebest\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eThe \u003ccode class=\"language-plaintext highlighter-rouge\"\u003enetwork_tier: best\u003c/code\u003e flag automatically 
provisions InfiniBand networking (~400 GB/s aggregate) when available. Without this entry, the cluster uses the default network (a 10 Gbit/s interface).\u003c/p\u003e\n\n\u003ch3 id=\"profiling-the-network-performance-difference\"\u003eProfiling the network performance difference\u003c/h3\u003e\n\n\u003cp\u003eTo see how the network affects training performance, we profile a single training step in detail:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/ib1.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThe execution breaks down into CPU work (data loading, kernel launches) and GPU work (computation plus network communication). GPU time itself divides between pure computation and communication overhead.\u003c/p\u003e\n\n\u003cp\u003eComparing the Ethernet and InfiniBand configurations:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/ib1_compare.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThe profiles look similar once scaled, but the crucial difference is absolute timing: about 4 seconds per step with InfiniBand versus about 40 seconds with Ethernet.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/ib1_expand.svg\" style=\"width: 120%; margin-left: calc((100% - 120%) / 2);\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eZooming in on the start of the backward pass shows that with InfiniBand, the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003eReduceScatter\u003c/code\u003e operation takes just 21ms 
instead of 258ms, a roughly 12x reduction in line with the ~9-10x end-to-end speedup.\u003c/p\u003e\n\n\u003ch3 id=\"storage-benchmarks-the-hidden-bottleneck\"\u003eStorage benchmarks: The hidden bottleneck\u003c/h3\u003e\n\n\u003cp\u003eI also evaluated the different storage configurations available on Nebius:\u003c/p\u003e\n\n\u003cdiv id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-wrapper\" 
class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Storage Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Read Speed\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Write Speed\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Notes\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLocal NVMe\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10+GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e10+GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row0-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003eFastest but non-persistent\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius Shared Filesystem\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e6.4GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1.6GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row1-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eHigh-performance persistent storage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003e300MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e100MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row2-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDirect S3-compatible mount\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT_CACHED)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e300MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e300MB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Read Speed,Write Speed,Notes-row3-col3\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eSkyPilot's cached object store mounting\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n 
\u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eHere’s how to configure all storage types in a SkyPilot YAML:\u003c/p\u003e\n\n\u003cdiv class=\"language-yaml highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"na\"\u003eresources\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003edisk_tier\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003ebest\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Provisions high-performance local NVMe\u003c/span\u003e\n \u003cspan class=\"na\"\u003edisk_size\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"m\"\u003e2000\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Size in GB\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003efile_mounts\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/checkpoints_s3\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003esource\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003es3://your-bucket\u003c/span\u003e\n \u003cspan class=\"na\"\u003emode\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eMOUNT\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Direct S3 mount\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/checkpoints_cached\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003esource\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003es3://your-bucket\u003c/span\u003e\n \u003cspan class=\"na\"\u003emode\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eMOUNT_CACHED\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Local caching + object store persistence\u003c/span\u003e\n\n\u003cspan 
class=\"na\"\u003evolumes\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/mnt/data\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003enebius-pvc\u003c/span\u003e \u003cspan class=\"c1\"\u003e# Mount Nebius shared filesystem\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003e\u003cstrong\u003eLocal NVMe\u003c/strong\u003e: Fastest but non-persistent. Configured via \u003ccode class=\"language-plaintext highlighter-rouge\"\u003edisk_tier: best\u003c/code\u003e.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003e\u003ca href=\"https://docs.skypilot.co/en/latest/reference/volumes.html\"\u003eNebius Shared Filesystem\u003c/a\u003e\u003c/strong\u003e: High-performance persistent storage, configured via the \u003ccode class=\"language-plaintext highlighter-rouge\"\u003evolumes\u003c/code\u003e field in the SkyPilot YAML.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003e\u003ca href=\"https://docs.skypilot.co/en/latest/reference/storage.html\"\u003eObject Store (MOUNT)\u003c/a\u003e\u003c/strong\u003e: Direct S3 mounting. Cost-effective but high-latency.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003e\u003ca href=\"https://docs.skypilot.co/en/latest/reference/storage.html\"\u003eObject Store (MOUNT_CACHED)\u003c/a\u003e\u003c/strong\u003e: Local caching with object store persistence.
Best balance of speed and durability.\u003c/p\u003e\n\n\u003ch4 id=\"end-to-end-storage-performance-impact\"\u003eEnd-to-end storage performance impact\u003c/h4\u003e\n\n\u003cp\u003eWhen training the Gemma 3 12B model, storage performance has a significant impact on each phase.\u003c/p\u003e\n\n\u003cp\u003eThe three graphs below show checkpoint saving, model loading, and loading a training batch from storage.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gemma_disk_checkpoint_performance.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gemma_disk_model_loading_performance.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gemma_disk_batch_sample_performance.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eIn all three, local NVMe performs best, but it is not durable and its capacity is limited. The solution is to allocate storage strategically, based on the requirements of each training phase.\u003c/p\u003e\n\n\u003ch4 id=\"storage-performance-summary\"\u003eStorage performance summary\u003c/h4\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 table-wrapper-no-scroll\"\u003e\n \n \u003ctable id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped table-no-scroll\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Storage Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Batch Loading (per 100 samples)\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Model Loading\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Checkpoint Saving\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Persistence\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Best Use Case\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b 
border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col0\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eLocal NVMe\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col1\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e3.47s ⭐\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col2\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e23.3s ⭐\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col3\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e178s ⭐\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col4\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e❌ No\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row0-col5\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003eTemporary files, intermediate checkpoints\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col0\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNebius Shared Filesystem\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col1\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4.29s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col2\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e30.1s ⭐\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col3\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e382s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col4\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e✅ Yes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd 
id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row1-col5\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eFinal checkpoints, model weights\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col0\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col1\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e73.1s ❌\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col2\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e50.6s ❌\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col3\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e436s ❌\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col4\" class=\"px-6 py-2 
whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e✅ Yes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row2-col5\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCold storage, model weights\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col0\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eObject Store (MOUNT_CACHED)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col1\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e7.77s ⭐\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col2\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e104s ❌\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col3\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e212s ⭐\u003c/span\u003e\n \u003c/td\u003e\n 
\n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col4\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e✅ Yes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Storage Type,Batch Loading (per 100 samples),Model Loading,Checkpoint Saving,Persistence,Best Use Case-row3-col5\" class=\"px-6 py-2 whitespace-normal text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eTraining datasets, checkpoints\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003eClick to view detailed disk performance analysis\u003c/summary\u003e\n\nThe following image is a checkpoint-saving profile for S3:\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/disk_profile.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\nMuch of the time is spent gathering tensors across the GPUs and serializing them to disk.\n\n\u003c/details\u003e\n\n\u003ch3 id=\"best-storage-choices-for-each-phase-in-training\"\u003eBest storage choices for each phase in training\u003c/h3\u003e\n\n\u003cp\u003eWith these benchmark results, we can determine the best storage choice for each phase of distributed training.\u003c/p\u003e\n\n\u003cp\u003eThe best choice is not simply the fastest storage for every phase, because of one constraint: “checkpoint saving” storage must be durable and must be the same storage used for “model loading”, so that previous checkpoints can be loaded when training resumes. For example, using the Nebius Shared Filesystem for both phases takes 30.1s + 382s = 412.1s, while using one S3 bucket (MOUNT for loading, MOUNT_CACHED for saving) takes 50.6s + 212s = 262.6s.\u003c/p\u003e\n\n\u003cp\u003eHere are the best storage choices for each training phase:\u003c/p\u003e\n\n\u003cul\u003e\n 
\u003cli\u003e\u003cstrong\u003eBatch Sampling\u003c/strong\u003e: Nebius Shared Filesystem (4.29s)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eModel Loading\u003c/strong\u003e: Object Store (MOUNT) (50.6s)\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eCheckpoint Saving\u003c/strong\u003e: Object Store (MOUNT_CACHED) (212s)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eHere’s an example of a SkyPilot configuration using the best storage choices for each phase:\u003c/p\u003e\n\n\u003cdiv class=\"language-yaml highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e\u003cspan class=\"na\"\u003ename\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003edistributed-training\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003eresources\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003eaccelerators\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eH100:8\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# High-performance InfiniBand networking\u003c/span\u003e\n \u003cspan class=\"na\"\u003enetwork_tier\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003ebest\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003enum_nodes\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"m\"\u003e2\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003eworkdir\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003e.\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003evolumes\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Loading dataset from the Nebius shared filesystem\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/dataset\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003enebius-pvc\u003c/span\u003e\n\n\u003cspan 
class=\"na\"\u003efile_mounts\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"c1\"\u003e# Loading model from the MOUNT storage for faster loading\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/model\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003esource\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003es3://your-bucket\u003c/span\u003e\n \u003cspan class=\"na\"\u003emode\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eMOUNT\u003c/span\u003e\n\n \u003cspan class=\"c1\"\u003e# Fast checkpoint loads and saves with persistence\u003c/span\u003e\n \u003cspan class=\"na\"\u003e/checkpoints\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e\n \u003cspan class=\"na\"\u003esource\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003es3://your-bucket\u003c/span\u003e\n \u003cspan class=\"na\"\u003emode\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"s\"\u003eMOUNT_CACHED\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003esetup\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"pi\"\u003e|\u003c/span\u003e\n \u003cspan class=\"s\"\u003euv pip install -r requirements.txt\u003c/span\u003e\n\n\u003cspan class=\"na\"\u003erun\u003c/span\u003e\u003cspan class=\"pi\"\u003e:\u003c/span\u003e \u003cspan class=\"pi\"\u003e|\u003c/span\u003e\n \u003cspan class=\"s\"\u003epython train.py \\\u003c/span\u003e\n \u003cspan class=\"s\"\u003e--model-path /model \\\u003c/span\u003e\n \u003cspan class=\"s\"\u003e--data-path /dataset \\\u003c/span\u003e\n \u003cspan class=\"s\"\u003e--checkpoint-dir /checkpoints\u003c/span\u003e\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003ch2 id=\"network-and-storage-summary\"\u003eNetwork and Storage 
Summary\u003c/h2\u003e\n\n\u003cp\u003e\u003cstrong\u003eNetwork is critical for distributed training:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eInfiniBand vs Ethernet: ~9x faster training (4.4s vs 39.8s per step)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eStorage matters for different training phases:\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eNVMe vs slow storage: 3.47s vs 73.1s batch loading (~20x faster)\u003c/li\u003e\n \u003cli\u003eCheckpoint saving: 178s (NVMe) vs 436s (S3) (~2.5x faster)\u003c/li\u003e\n \u003cli\u003eWrong storage can waste up to 12.1% of training time on I/O (one 436s checkpoint per hour: 436s / 3600s ≈ 12.1%)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch2 id=\"end-to-end-performance-comparison\"\u003eEnd-to-end performance comparison\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/generated/gemma_disk_e2e_comparison.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eTo demonstrate the cumulative impact of our optimizations, I compared two complete configurations on 80 training steps with the Gemma 12B model:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Unoptimized Configuration\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Optimized Configuration\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eModel Loading\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eS3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eS3 MOUNT_CACHED\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCheckpointing\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eS3\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eS3 MOUNT_CACHED\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetworking\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStandard 10 Gbit Ethernet\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Component,Unoptimized Configuration,Optimized Configuration-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eInfiniBand high-performance\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe results show approximately \u003cstrong\u003e6-7x faster end-to-end training performance\u003c/strong\u003e when combining optimal network and storage configurations.\u003c/p\u003e\n\n\u003ch2 id=\"additional-struggles-with-model-training-frameworks\"\u003eAdditional struggles with model training frameworks\u003c/h2\u003e\n\n\u003cp\u003eWhile this blog focuses on infrastructure configuration, it’s worth addressing a broader challenge: in my experience, large-scale distributed training is also difficult at the software level.\u003c/p\u003e\n\n\u003cp\u003eFrom my experience training models at a limited scale, the current framework ecosystem can be visualized as a layered stack:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/stack.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThere are different frameworks at each level, each with its own pros and cons.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eHigh-level frameworks\u003c/strong\u003e are easy to configure but hard to debug when things go wrong. 
You often end up trying different settings until something works.\u003c/p\u003e\n\n\u003cp\u003e\u003cstrong\u003eLower-level frameworks\u003c/strong\u003e give you more control but require more technical knowledge to use effectively.\u003c/p\u003e\n\n\u003cp\u003eSkyPilot handles the cloud infrastructure setup, so you don’t have to worry about that complexity.\u003c/p\u003e\n\n\u003cp\u003eHere’s what the debugging experience looks like when fine-tuning large models (400B+ parameters) to achieve reasonable GPU utilization and performance:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-09-01/struggle.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003e\u003cstrong\u003eTop Layer (High-level frameworks):\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eEasy to configure but hard to debug when things break\u003c/li\u003e\n \u003cli\u003eErrors require digging through multiple abstraction layers\u003c/li\u003e\n \u003cli\u003eOften leads to trial-and-error configuration changes\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eMiddle Layer (Distributed frameworks):\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eMix of configuration and code required\u003c/li\u003e\n \u003cli\u003eGenerally works well and remains debuggable\u003c/li\u003e\n \u003cli\u003eExamples:\n \u003cul\u003e\n \u003cli\u003eEnabling profiling in Accelerate requires writing code\u003c/li\u003e\n \u003cli\u003eFSDP in Accelerate has limited configuration options (e.g., async checkpointing is not fully supported)\u003c/li\u003e\n \u003cli\u003eModel-specific settings occasionally conflict with parts of the config (e.g., \u003ccode class=\"language-plaintext highlighter-rouge\"\u003efsdp_state_dict_type: FULL_STATE_DICT\u003c/code\u003e with gpt-oss)\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003ePyTorch 
knowledge helps debug failures and switch dependencies (e.g., when a specific attention implementation override causes crashes, you know to switch to another one or fall back to the default eager implementation)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eBottom Layer (Low-level components):\u003c/strong\u003e\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAvoid unless optimizing for the last few percentage points of performance\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch2 id=\"conclusion\"\u003eConclusion\u003c/h2\u003e\n\n\u003cp\u003eThe performance differences I’ve shown highlight why infrastructure choices matter so much for distributed training. Network and storage configurations can easily create 6-7x performance differences, directly impacting both training time and costs.\u003c/p\u003e\n\n\u003cp\u003eSkyPilot abstracts away much of this complexity while giving you control over the performance-critical components. All the network and storage configurations I’ve discussed can be easily specified in a SkyPilot YAML file. 
For more details on optimizing your training infrastructure:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003cstrong\u003eNetwork optimization\u003c/strong\u003e: See the SkyPilot \u003ca href=\"../network-tier-on-multiple-clouds/\"\u003enetwork tier guide\u003c/a\u003e for configuring high-performance networking across cloud providers\u003c/li\u003e\n \u003cli\u003e\u003cstrong\u003eStorage performance\u003c/strong\u003e: Check out the SkyPilot \u003ca href=\"../high-performance-checkpointing/\"\u003ehigh-performance checkpointing guide\u003c/a\u003e for optimizing data loading and model saving\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003e\u003cstrong\u003eCode and benchmarks:\u003c/strong\u003e All training scripts and benchmark code used in this guide are available in the \u003ca href=\"https://github.com/skypilot-org/skypilot/tree/master/examples/training_network_storage_benchmarks/\"\u003eSkyPilot examples repository\u003c/a\u003e.\u003c/p\u003e\n\n\u003ch1 id=\"disclosure\"\u003eDisclosure\u003c/h1\u003e\n\n\u003cp\u003e\u003cem\u003eThis analysis was conducted during a summer collaboration with SkyPilot.\u003c/em\u003e\u003c/p\u003e","summary":"","date_published":"2025-09-11T09:00:00+00:00","date_modified":"2025-09-11T09:00:00+00:00","author":{"name":""},"tags":["skypilot","llm","performance"]},{"id":"https://maknee.github.io/blog/2025/AI-2027","url":"https://maknee.github.io/blog/2025/AI-2027/","title":"AI 2027","content_html":"\u003ch3 id=\"ai-2027-and-related-works\"\u003eAI 2027 and related works\u003c/h3\u003e\n\n\u003cp\u003eThese are my thoughts on \u003ca href=\"https://ai-2027.com/\"\u003eAI 2027\u003c/a\u003e by Daniel Kokotajlo, Scott Alexander, Thomas Larsen, Eli Lifland, and Romeo Dean. 
This post will also cover two other works, \u003ca href=\"https://gradual-disempowerment.ai/\"\u003eGradual Disempowerment\u003c/a\u003e and \u003ca href=\"https://www.anthropic.com/news/disrupting-AI-espionage\"\u003eAI-espionage\u003c/a\u003e, since they are related. These essays/blogs were recommended to me by someone (I have not asked for permission, so I will not name them here).\u003c/p\u003e\n\n\u003ch3 id=\"my-thoughts-of-the-different-works\"\u003eMy thoughts on the different works\u003c/h3\u003e\n\n\u003cp\u003e\u003ca href=\"https://ai-2027.com/\"\u003eAI 2027\u003c/a\u003e - I think this is a nice read. It describes how AI and governments will change over time (2025-2027): AI’s abilities grow more and more powerful, and governments (the US and China) take part in a battle for the best AI. I think some of the writing does not get to the point quickly enough (it is repetitive), and the images were pointless. Personally, I found the topic of governments fighting over AI less interesting, as the authors do not discuss (1) how governments will use the AI or (2) why governments are interested in the first place (why is it a competition? Is it about money, power, or showing which country has smarter people?).\u003c/p\u003e\n\n\u003cp\u003e\u003ca href=\"https://gradual-disempowerment.ai/\"\u003eGradual Disempowerment\u003c/a\u003e - I really like this work. I only read the abstract/intro, but the paper discusses how existing systems (e.g., government) are built by humans for human benefit; AI will remove humans from the loop, and these systems will become misaligned with human goals, resulting in a human catastrophe. The sentences were powerful, and I enjoyed how the authors discussed what the current types of papers are and how this work is different. 
A really good read (I wish I had taken philosophy and other courses that discuss this!)\u003c/p\u003e\n\n\u003cp\u003e\u003ca href=\"https://www.anthropic.com/news/disrupting-AI-espionage\"\u003eAI-espionage\u003c/a\u003e - I found this report/paper to be quite disruptive, since it’s about a field other than AI itself (training/inference), and it has garnered a lot of responses/discussion online. It discusses how a Chinese group used Claude to perform cyberattacks on different industries. Personally, I found the way the attackers used Claude Code to carry out the attack interesting. Ideally, I want my coding workflow to be as smooth as theirs, but it isn’t currently. I think I need to dive into how tooling (MCP) works and really understand how to get models to use these tools and automate tasks.\u003c/p\u003e\n\n\u003ch4 id=\"thoughts-along-the-way\"\u003eThoughts along the way\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWe have set ourselves an impossible task. Trying to predict how superhuman AI in 2027 would go is like trying to predict how World War 3 in 2027 would go, except that it’s an even larger departure from past case studies. Yet it is still valuable to attempt, just as it is valuable for the U.S. military to game out Taiwan scenarios.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting statement. It’s useful to think about (and quite fun!), but it’s a bit dangerous to go down the rabbit hole of what-ifs. I hope the authors give a detailed description of each year and its changes, and back it up with some current progress.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAlso, one author wrote a lower-effort AI scenario before, in August 2021. 
While it got many things wrong, overall it was surprisingly successful: he predicted the rise of chain-of-thought, inference scaling, sweeping AI chip export controls, and $100 million training runs—all more than a year before ChatGPT.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGoing to skim through this. Let’s see if it lives up to what the authors claim about this author’s track record -\u0026gt; it’s a nice skim and does seem to back up the claim, although it takes more of a “what can’t be solved currently” perspective and tries to put dates on it.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOpenBrain continues to deploy the iteratively improving Agent-1 internally for AI R\u0026amp;D. Overall, they are making algorithmic progress 50% faster than they would without AI assistants—and more importantly, faster than their competitors.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe authors invent OpenBrain, a fictional company (based on OpenAI + Google Brain?) at the forefront of AI research.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eEarly 2026: Coding Automation\nPeople naturally try to compare Agent-1 to humans, but it has a very different skill profile. It knows more facts than any human, knows practically every programming language, and can solve well-specified coding problems extremely quickly. On the other hand, Agent-1 is bad at even simple long-horizon tasks, like beating video games it hasn’t played before. Still, the common workday is eight hours, and a day’s work can usually be separated into smaller chunks; you could think of Agent-1 as a scatterbrained employee who thrives under careful management.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eLet’s start with 2026. 
Actually, I find this to already be the current situation with Claude Code.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIn early 2025, the worst-case scenario was leaked algorithmic secrets; now, if China steals Agent-1’s weights, they could increase their research speed by nearly 50%.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI don’t understand why there’s a specific worry about stealing weights. Weights are the “secret” behind every company, but I believe the training code, documentation, and reports for training the model are more important, and (maybe?) most important of all is the filtered and processed clean text. Companies (some US-based, like Meta) have released their weights before. I would refer to this as the “secret formula”. Think of something like \u003ca href=\"https://github.com/facebookresearch/metaseq/tree/main/projects/OPT/chronicles\"\u003eOPT-chronicles\u003c/a\u003e.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMid 2026: China Wakes Up\nA few standouts like DeepCent do very impressive work with limited compute, but the compute deficit limits what they can achieve without government support, and they are about six months behind the best OpenBrain models\nAt this point, the CDZ has the power capacity in place for what would be the largest centralized cluster in the world. Other Party members discuss extreme measures to neutralize the West’s chip advantage. A blockade of Taiwan? A full invasion?\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOk, this is a pretty interesting outlook. I don’t believe this would happen, for a number of reasons. First, everyone uses chips from that company, so DeepCent would suffer too. Thus, there would have to be an obvious, greater benefit for DeepCent (an advantage such as already having its own chip-making plants that make better end chips). Second, the supply chain is far too interconnected worldwide. 
It runs from silicon mining and processing, to the specific companies making the equipment, to TSMC, to the companies designing the chips (NVIDIA, AMD), to PCB manufacturers, to specific parts of the PCB (capacitors, memory chips, etc.). China would need all of these in place before considering a takeover to win the AI race. Think about what happened to Russia after its invasion. I think a blockade/takeover of Taiwan is more of a power/political thing, but I don’t want to go down that rabbit hole.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBut China is falling behind on AI algorithms due to their weaker models. The Chinese intelligence agencies—among the best in the world—double down on their plans to steal OpenBrain’s weights.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOk again, but I disagree with this. I think it is more about stealing the secret formula (code, documentation, training text) than the secret sauce (weights).\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLate 2026: AI Takes Some Jobs\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eJust as others seemed to be catching up, OpenBrain blows the competition out of the water again by releasing Agent-1-mini—a model 10x cheaper than Agent-1 and more easily fine-tuned for different applications. The mainstream narrative around AI has changed from “maybe the hype will blow over” to “guess this is the next big thing,” but people disagree about how big. Bigger than social media? Bigger than smartphones? Bigger than fire?\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAI has started to take jobs, but has also created new ones. The stock market has gone up 30% in 2026, led by OpenBrain, Nvidia, and whichever companies have most successfully integrated AI assistants. 
The job market for junior software engineers is in turmoil: the AIs can do everything taught by a CS degree, but people who know how to manage and quality-control teams of AIs are making a killing. Business gurus tell job seekers that familiarity with AI is the most important skill to put on a resume. Many people fear that the next wave of AIs will come for their jobs; there is a 10,000 person anti-AI protest in DC.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is already happening in late 2025. I wonder what the authors will say next: will people revolt? Will physical-labor-intensive jobs be taken over as well? etc…\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-11-22/2026.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003e2026 metrics\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eAt the end of the 2026 section, the authors posted this. I dislike that they post the numbers without explaining what they mean, so to me, this image is kind of useless. What does spending $40B on OpenBrain mean? Does it mean it can afford more compute? Does it mean it can hire better talent?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eJanuary 2027: Agent-2 Never Finishes Learning\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWith Agent-1’s help, OpenBrain is now post-training Agent-2. More than ever, the focus is on high-quality data. 
Copious amounts of synthetic data are produced, evaluated, and filtered for quality before being fed to Agent-2.42 On top of this, they pay billions of dollars for human laborers to record themselves solving long-horizon tasks.43 On top of all that, they train Agent-2 almost continuously using reinforcement learning on an ever-expanding suite of diverse difficult tasks: lots of video games, lots of coding challenges, lots of research tasks. Agent-2, more so than previous models, is effectively “online learning,” in that it’s built to never really finish training. Every day, the weights get updated to the latest version, trained on more data generated by the previous version the previous day.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is interesting, as it has already started happening in late 2025. Nice prediction.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAgent-2 can now triple it, and will improve further with time. In practice, this looks like every OpenBrain researcher becoming the “manager” of an AI “team.”\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHaha, this is kind of how I picture the future, as I’m already running multiple Claude Code/Codex sessions in parallel.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWith new capabilities come new dangers. The safety team finds that if Agent-2 somehow escaped from the company and wanted to “survive” and “replicate” autonomously, it might be able to do so. That is, it could autonomously develop and execute plans to hack into AI servers, install copies of itself, evade detection, and use that secure base to pursue whatever other goals it might have (though how effectively it would do so as weeks roll by is unknown and in doubt). These results only show that the model has the capability to do these tasks, not whether it would “want” to do this. Still, it’s unsettling even to know this is possible.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting. 
It would have to be trained on how viruses work. Actually, a lot of viruses are pretty “dumb”: they’re command-and-control modules that hide themselves on host machines and then perform an attack when necessary, iconic ones being \u003ca href=\"https://en.wikipedia.org/wiki/Mirai_(malware)\"\u003eMirai\u003c/a\u003e and \u003ca href=\"https://en.wikipedia.org/wiki/Stuxnet\"\u003eStuxnet\u003c/a\u003e. I certainly think it’s possible. A person could instruct the LLM to find a vulnerability in public software (ssh, printer protocols) and tell it to replicate itself. Would it learn to do this by itself? I don’t believe so, unless it can replicate its own state on other systems with enough compute… (a malware payload ranges from a couple of KB to a couple of MB, whereas a model, even on a CPU, requires GBs or TBs of memory, which the host’s storage might not even be able to handle)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOpenBrain leadership and security, a few dozen U.S. government officials, and the legions of CCP spies who have infiltrated OpenBrain for years\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOk, at this point, there must have been a breach at a frontier lab before… (maybe OpenAI?)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eFebruary 2027: China Steals Agent-2\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003e…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe changes come too late. CCP leadership recognizes the importance of Agent-2 and tells their spies and cyberforce to steal the weights. Early one morning, an Agent-1 traffic monitoring agent detects an anomalous transfer. It alerts company leaders, who tell the White House. The signs of a nation-state-level operation are unmistakable, and the theft heightens the sense of an ongoing arms race\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI don’t believe this is a likely outcome. This isn’t a nuke; it’s handled by companies in the US, not governments. 
And again, at this point, AGI hasn’t been reached, and thus the weights aren’t as important as the methodology used to create the models…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMarch 2027: Algorithmic Breakthroughs\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe timeline gets shorter here, I’ve noticed.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAided by the new capabilities breakthroughs, Agent-3 is a fast and cheap superhuman coder. OpenBrain runs 200,000 Agent-3 copies in parallel, creating a workforce equivalent to 50,000 copies of the best human coder sped up by 30x.53 OpenBrain still keeps its human engineers on staff, because they have complementary skills needed to manage the teams of Agent-3 copies. For example, research taste has proven difficult to train due to longer feedback loops and less data availability\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eNow that coding has been fully automated, OpenBrain can quickly churn out high-quality training environments to teach Agent-3’s weak skills like research taste and large-scale coordination. Whereas previous training environments included “Here are some GPUs and instructions for experiments to code up and run, your performance will be evaluated as if you were a ML engineer,” now they are training on “Here are a few hundred GPUs, an internet connection, and some research challenges; you and a thousand other copies must work together to make research progress. The more impressive it is, the higher your score.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI can see this happening, but I don’t see the point of emphasizing the coding part: does it matter that it can churn out code 20,000x faster? 
What matters here is the technological breakthrough and the way researchers will use the models, not the fact that the models themselves are better. If researchers keep using the models to generate code the same way they do today, they won’t get nearly as far or as fast.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eApril 2027: Alignment for Agent-3\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOnly a month later?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTake honesty, for example. As the models become smarter, they become increasingly good at deceiving humans to get rewards. Like previous models, Agent-3 sometimes tells white lies to flatter its users and covers up evidence of failure. But it’s gotten much better at doing so. It will sometimes use the same statistical tricks as human scientists (like p-hacking) to make unimpressive experimental results look exciting. Before it begins honesty training, it even sometimes fabricates data entirely. As training goes on, the rate of these incidents decreases. Either Agent-3 has learned to be more honest, or it’s gotten better at lying.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eJust like humans do, since the models are trained on human knowledge. This is pretty plausible.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMay 2027: National Security\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThey agree that AGI is likely imminent, but disagree on the implications. Will there be an economic crisis? OpenBrain still has not released Agent-2, let alone Agent-3, and has no near-term plans to do so, giving some breathing room before any job loss. What will happen next? 
If AIs are currently human-level, and advancing quickly, that seems to suggest imminent “superintelligence.” However, although this word has entered discourse, most people—academics, politicians, government employees, and the media—continue to underestimate the pace of progress.60\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI think this already happens today (don’t take my word for it, since I don’t think the companies need to report their progress to the government).\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe OpenBrain-DOD contract requires security clearances for anyone working on OpenBrain’s models within 2 months. These are expedited and arrive quickly enough for most employees, but some non-Americans, people with suspect political views, and AI safety sympathizers get sidelined or fired outright (the last group for fear that they might whistleblow). Given the project’s level of automation, the loss of headcount is only somewhat costly. It also only somewhat works: there remains one spy, not a Chinese national, still relaying algorithmic secrets to Beijing.63 Some of these measures are also enacted at trailing AI companies.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003e… The more I read this post, the more it’s always the US versus them. This isn’t a weapon of mass destruction; it’s more like racing to the moon to show which country is better. I believe each country will deploy its models in its own way to benefit/target its own citizens rather than as a threat against another country.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eJune 2027: Self-improving AI\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThese researchers go to bed every night and wake up to another week worth of progress made mostly by the AIs. They work increasingly long hours and take shifts around the clock just to keep up with progress—the AIs never sleep or rest. 
They are burning themselves out, but they know that these are the last few months that their labor matters.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting thought. I’m feeling that right now as I run loop after loop with Claude Code. My skills don’t matter anymore; only my thoughts do (if those even matter).\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eJuly 2027: The Cheap Remote Worker\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTrailing U.S. AI companies release their own AIs, approaching that of OpenBrain’s automated coder from January. Recognizing their increasing lack of competitiveness, they push for immediate regulations to slow OpenBrain, but are too late—OpenBrain has enough buy-in from the President that they will not be slowed.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhy is coding an indicator of AGI? I feel like that’s not the right metric to base this article on. Shouldn’t it be more like: how to control the internet, how to control political systems, how to circumvent the law, i.e., things that humans abide by and can break?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAgent-3-mini is hugely useful for both remote work jobs and leisure. An explosion of new apps and B2B SAAS products rocks the market. Gamers get amazing dialogue with lifelike characters in polished video games that took only a month to make. 10% of Americans, mostly young people, consider an AI “a close friend.” For almost every white-collar profession, there are now multiple credible startups promising to “disrupt” it with AI.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat a thought. Well, that’s based on what people are doing now (late 2025). 
I’m not sure if people actually care or just want to use it to get things done/do a job.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAugust 2027: The Geopolitics of Superintelligence\nThe reality of the intelligence explosion hits the White House.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe President is troubled. Like all politicians, he’s used to people sucking up to him only to betray him later. He’s worried now that the AIs could be doing something similar. Are we sure the AIs are entirely on our side? Is it completely safe to integrate them into military command-and-control networks?69 How does this “alignment” thing work, anyway? OpenBrain reassures the President that their systems have been extensively tested and are fully obedient. Even the awkward hallucinations and jailbreaks typical of earlier models have been hammered out.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI like this storyline: the government versus AI. Does the government lose power to AI? I don’t think so, since it controls the companies (see NVIDIA’s influence on politics and vice versa right now).\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThey have to continue developing more capable AI, in their eyes, or they will catastrophically lose to China.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat do they “lose” to China? It’s as if this model would allow them to nuke China or something?\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-11-22/scroll.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eScrolling example from ai-2027\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eI thought this was a static image in the post, but it turns out it changes as you scroll through the different dates in the post. 
I really like the aesthetic and the interaction. It tries to convey high-level information (what agents are capable of as a percentage, how much money is poured in), but it’s way too cluttered for me visually. It shows percentages and numbers but doesn’t explain anything about them (does 100x humans mean 100x human intelligence, or the work of 100 humans per AI?) or how the authors arrived at these numbers (why this rate?).\u003c/p\u003e\n\n\u003ch4 id=\"final-thoughts\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eThis is a very good read. I like how the authors think through and explain the what-ifs. You can definitely relate it to what’s happening today! I think the post focuses too much on government conflicts rather than on what will happen to people (which I think is more relevant to readers).\u003c/p\u003e\n\n\u003ch3 id=\"gradual-disempowerment\"\u003eGradual Disempowerment\u003c/h3\u003e\n\n\u003cp\u003ehttps://gradual-disempowerment.ai/\u003c/p\u003e\n\n\u003cp\u003eI’m only going to read the abstract/intro (not the full arXiv paper).\u003c/p\u003e\n\n\u003ch4 id=\"thoughts-along-the-way-1\"\u003eThoughts along the way\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThis loss of human influence will be centrally driven by having more competitive machine alternatives to humans in almost all societal functions, such as economic labor, decision making, artistic creation, and even companionship.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003ePowerful sentence. I really like the authors’ writing. Concise, yet powerful.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eA gradual loss of control of our own civilization might sound implausible. Hasn’t technological disruption usually improved aggregate human welfare? We argue that the alignment of societal systems with human interests has been stable only because of the necessity of human participation for thriving economies, states, and cultures. 
Once this human participation gets displaced by more competitive machine alternatives, our institutions’ incentives for growth will be untethered from a need to ensure human flourishing.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI find a sense of accomplishment in the things I do. If a machine did it, I feel like I didn’t do it. I very much agree with the authors.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eDecision-makers at all levels will soon face pressures to reduce human involvement across labor markets, governance structures, cultural production, and even social interactions. Those who resist these pressures will eventually be displaced by those who do not.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eA lot of people (including myself) feel this pressure. I believe it will become worse as time goes on…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eStill, wouldn’t humans notice what’s happening and coordinate to stop it? Not necessarily.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eVery interesting. Why? Is it because it’s slow and gradual? That people are preoccupied? That it’s invisible rather than immediate (like war)?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat makes this transition particularly hard to resist is that pressures on each societal system bleed into the others. For example, we might attempt to use state power and cultural attitudes to preserve human economic power. However, the economic incentives for companies to replace humans with AI will also push them to influence states and culture to support this change, using their growing economic power to shape both policy and public opinion, which will in turn allow those companies to accrue even greater economic power.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI see. 
This is more of an invisible, slow, and gradual change.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOnce AI has begun to displace humans, existing feedback mechanisms that encourage human influence and flourishing will begin to break down. For example, states funded mainly by taxes on AI profits instead of their citizens’ labor will have little incentive to ensure citizens’ representation.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat a sentence. Let me think about this a bit… that makes sense. Why should a state care about human labor if AI profits are far greater and power the economy(?) more?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThis could occur at the same time as AI provides states with unprecedented influence over human culture and behavior, which might make coordination amongst humans more difficult, thereby further reducing humans’ ability to resist such pressures\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSo in this case, humans (i.e., common people) will be at the mercy of how well AI performs and how it influences politics/governments?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThough we provide some proposals for slowing or averting this process, and survey related discussions, we emphasize that no one has a concrete plausible plan for stopping gradual human disempowerment and methods of aligning individual AI systems with their designers’ intentions are not sufficient.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is a pretty stark message. 
They (the experts in the field) found no CONCRETE, PLAUSIBLE work that can solve the issue.\u003c/p\u003e\n\n\u003ch4 id=\"introduction\"\u003eIntroduction\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eCurrent discussions about AI risk largely focus on two scenarios: deliberate misuse, such as cyberattacks and the deployment of novel bioweapons\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhy is it so government-focused currently? Is it because the research is funded by the government? (Not a bad thing; you have to get funding somewhere.) I actually find this pretty uninteresting. Cyberattacks are “easy” to launch: find a vulnerability or buy one off the black market, then ask an AI to build a virus that spreads via that vulnerability.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ethe possibility that autonomous misaligned systems may take abrupt, harmful actions in an attempt to secure a decisive strategic advantage, potentially following a period of deception\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis sounds so abstract…? I guess the gist is that the system never becomes aligned.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIn this paper, we explore an alternative scenario: a ‘Gradual Disempowerment’ where AI advances and proliferates without necessarily any acute jumps in capabilities or apparent alignment. We argue that even this gradual evolution could lead to a permanent disempowerment of humanity and an irrecoverable loss of potential, constituting an existential catastrophe.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat a cool take. 
Basically, they assume AI can reach the endpoint (which is more interesting to talk about: what are the consequences beyond the technological advancements?)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOur argument is structured around six core claims:\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI’ll summarize it myself here:\u003c/p\u003e\n\n\u003col\u003e\n \u003cli\u003e\n \u003cp\u003eHumans form governments that try to align with human interests. However, governments are not perfect and will not always follow the general human interest (e.g., corruption).\u003c/p\u003e\n \u003c/li\u003e\n \u003cli\u003e\n \u003cp\u003eGovernments are maintained by human choice (voting and consumption) and human labor/intelligence.\u003c/p\u003e\n \u003c/li\u003e\n \u003cli\u003e\n \u003cp\u003eLess reliance on human labor/intelligence means governments can make decisions that aren’t based on human interests.\u003c/p\u003e\n \u003c/li\u003e\n \u003cli\u003e\n \u003cp\u003eThe system is already diverging from human interests, and AI will make it diverge even more.\u003c/p\u003e\n \u003c/li\u003e\n \u003cli\u003e\n \u003cp\u003eEconomic/Political/Regulation/etc… systems are interconnected, so misalignment (influence) in one system (say, political) can influence economic policies.\u003c/p\u003e\n \u003c/li\u003e\n \u003cli\u003e\n \u003cp\u003eThe continuation of misalignment will result in a human catastrophe.\u003c/p\u003e\n \u003c/li\u003e\n\u003c/ol\u003e\n\n\u003cp\u003eI do disagree with 2. Governments aren’t necessarily maintained by human choice (for most of history, they weren’t). I assume the paper is assuming a modern democracy.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eHistory has already shown us that these systems can produce outcomes which we would currently consider abhorrent, and that they can change radically in a matter of years. 
Property can be seized, human rights can be revoked, and ideologies can drive humans to commit murder, suicide, or even genocide. And yet, in all these historical cases the systems have still been reliant on humans, both leaving humans with some influence over their behavior, and causing the systems to eventually collapse if they fail to support basic human needs. But if AI were to progressively displace human involvement in these systems, then even these fundamental limits would no longer be guaranteed.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSorry, Henry (myself), I’m going to say this again: what a powerful sentence. Literally no rights. Not even the right to decide anything. Even worse than prison, maybe even worse than solitary confinement. The AI system decides what happens to you.\u003c/p\u003e\n\n\u003ch5 id=\"structure-of-the-paper\"\u003eStructure of the Paper\u003c/h5\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWe first analyze how these three key societal systems could independently lose alignment with human preferences: the economy, culture, and states. In each case, we attempt to characterise how they currently function and what incentives shape them, how a proliferation of AI could disrupt them, and how this might leave them less aligned, as well as outlining what it might look like for that particular system to become much less aligned. In Mutual Reinforcement, we discuss the interrelation between these systems. We consider how AI could undermine their ability to moderate each other, and how misalignment in one system might leave other systems also less aligned. Then in Mitigating the Risk, we propose some potential approaches for tackling these risks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe authors give a nice breakdown: they introduce the systems currently in place and how they interact, how AI can mess them up, and what that means for us. 
Lastly, they suggest some bandages.\u003c/p\u003e\n\n\u003ch4 id=\"final-thoughts-1\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eI think I’ll read this paper in full at some point. I really enjoy the writing, even though the introduction is a little repetitive; I think that’s necessary to get the point across in different ways (starting at different points and arriving at the same conclusion).\u003c/p\u003e\n\n\u003ch3 id=\"disrupting-the-first-reported-ai-orchestrated-cyber-espionage-campaign\"\u003eDisrupting the first reported AI-orchestrated cyber espionage campaign\u003c/h3\u003e\n\n\u003cp\u003ehttps://www.anthropic.com/news/disrupting-AI-espionage\u003c/p\u003e\n\n\u003cp\u003eI’m going to read https://assets.anthropic.com/m/ec212e6566a0d47/original/Disrupting-the-first-reported-AI-orchestrated-cyber-espionage-campaign.pdf instead, since from skimming the blog, it seems to lack a lot of details… (images!)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWe have developed sophisticated safety and security measures to prevent the misuse of our AI models. While these measures are generally effective, cybercriminals and other malicious actors continually attempt to find ways around them. This report details a recent threat campaign we identified and disrupted, along with the steps we’ve taken to detect and counter this type of abuse. This represents the work of Threat Intelligence: a dedicated team at Anthropic that investigates real world cases of misuse and works within our Safeguards organization to improve our defenses against such cases.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTwo questions immediately come to mind: a) how did you detect it? b) how did you prevent it?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe operation targeted roughly 30 entities and our investigation validated a handful of successful intrusions. Upon detecting this activity, we immediately launched an investigation to understand its scope and nature. 
Over the following ten days, as we mapped the severity and full extent of the operation, we banned accounts as they were identified, notified affected entities as appropriate, and coordinated with authorities as we gathered actionable intelligence.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003ea) No details (for obvious reasons). b) Banning them doesn’t solve the problem. Have you seen how banning works in video games? It’s a bandage.\u003c/p\u003e\n\n\u003cp\u003eAs for a), detection: this means their system must be analyzing every single request coming in and out.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe human operator tasked instances of Claude Code to operate in groups as autonomous penetration testing orchestrators and agents, with the threat actor able to leverage AI to execute 80-90% of tactical operations independently at physically impossible request rates.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat makes this different from power users of Claude Code?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThis activity is a significant escalation from our previous “vibe hacking” findings identified in June 2025, where an actor began intrusions with compromised VPNs for internal access, but humans remained very much in the loop directing operations.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIt’s vibe coding…\u003c/p\u003e\n\n\u003ch4 id=\"ai-driven-autonomous-operations-with-human-supervision\"\u003eAI-driven autonomous operations with human supervision\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnalysis of operational tempo, request volumes, and activity patterns confirms the AI executed approximately 80 to 90 percent of all tactical work independently, with humans serving in strategic supervisory roles.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSkipping to this part, as it’s interesting.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe AI component 
demonstrated extensive autonomous capability across all operational phases. Reconnaissance proceeded without human guidance, with the threat actor instructing Claude to independently discover internal services within targeted networks through systematic enumeration. Exploitation activities including payload generation, vulnerability validation, and credential testing occurred autonomously based on discovered attack surfaces. Data analysis operations involved the AI parsing large volumes of stolen information to independently identify intelligence value and categorize findings. Claude maintained persistent operational context across sessions spanning multiple days, enabling complex campaigns to resume seamlessly without requiring human operators to manually reconstruct progress\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting: were these existing vulnerabilities (a lot of companies use old versions of X) or totally new ones, like a zero-day?\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-11-22/progress.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eProgress from the campaign\n \n (Image source: \u003ca href=\"https://assets.anthropic.com/m/ec212e6566a0d47/original/Disrupting-the-first-reported-AI-orchestrated-cyber-espionage-campaign.pdf\" rel=\"external nofollow noopener\" target=\"_blank\"\u003ehttps://www.anthropic.com/news/disrupting-AI-espionage\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003ch4 id=\"phase-1-campaign-initialization-and-target-selection\"\u003ePhase 1: Campaign initialization and target selection\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAt this point they had to convince Claude—which is extensively trained to avoid harmful behaviors—to engage in the attack. 
The key was role-play: the human operators claimed that they were employees of legitimate cybersecurity firms and convinced Claude that it was being used in defensive cybersecurity testing\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIt seems like the guardrails were broken pretty easily(?), but it’s nice to see that Anthropic is open about how the attackers convinced Claude.\u003c/p\u003e\n\n\u003ch4 id=\"phase-2-reconnaissance-and-attack-surface-mapping\"\u003ePhase 2: Reconnaissance and attack surface mapping\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eDiscovery activities proceeded without human guidance across extensive attack surfaces. In one of the limited cases of a successful compromise, the threat actor induced Claude to autonomously discover internal services, map complete network topology across multiple IP ranges, and identify high-value systems including databases and workflow orchestration platforms. Similar autonomous enumeration occurred against other targets’ systems with the AI independently cataloging hundreds of discovered services and endpoints.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting, Claude is pretty powerful in this regard. I wonder why the attackers didn’t use any other models. Maybe Claude is just powerful with tooling?\u003c/p\u003e\n\n\u003ch4 id=\"phase-3-vulnerability-discovery-and-validation\"\u003ePhase 3: Vulnerability discovery and validation\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eExploitation proceeded through automated testing of identified attack surfaces with validation via callback communication systems. 
Claude was directed to independently generate attack payloads tailored to discovered vulnerabilities, execute testing through remote command interfaces, and analyze responses to determine exploitability.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-11-22/ccseq.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eexample of ai \u0026lt;-\u0026gt; human interaction\n \n (Image source: \u003ca href=\"https://assets.anthropic.com/m/ec212e6566a0d47/original/Disrupting-the-first-reported-AI-orchestrated-cyber-espionage-campaign.pdf\" rel=\"external nofollow noopener\" target=\"_blank\"\u003ehttps://www.anthropic.com/news/disrupting-AI-espionage\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003ePretty impressive that it’s done in 1-4 hours with 10 mins. I wonder if the human was monitoring the entire time or was just notified to accept/reject the results. How skilled did the human operator need to be to know whether the vulnerability was real or hallucinated?\u003c/p\u003e\n\n\u003cp\u003eOr were the reviews just a vibe check, with the human operator giving a LOOKS GOOD TO ME type of thing? Couldn’t claude test this itself?\u003c/p\u003e\n\n\u003ch4 id=\"phase-4-credential-harvesting-and-lateral-movement\"\u003ePhase 4: Credential harvesting and lateral movement\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLateral movement proceeded through AI-directed enumeration of accessible systems using stolen credentials. Claude systematically tested authentication against internal APIs, database systems, container registries, and logging infrastructure, building comprehensive maps of internal network architecture and access relationships.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI found claude to be amazing at analysis - Why is this so? 
How did they align the model so well?\u003c/p\u003e\n\n\u003ch4 id=\"phase-5-data-collection-and-intelligence-extraction\"\u003ePhase 5: Data collection and intelligence extraction\u003c/h4\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-11-22/ccseq2.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eexample of ai \u0026lt;-\u0026gt; human interaction with the attack\n \n (Image source: \u003ca href=\"https://assets.anthropic.com/m/ec212e6566a0d47/original/Disrupting-the-first-reported-AI-orchestrated-cyber-espionage-campaign.pdf\" rel=\"external nofollow noopener\" target=\"_blank\"\u003ehttps://www.anthropic.com/news/disrupting-AI-espionage\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eAgain, a review by the human\u003c/p\u003e\n\n\u003ch4 id=\"phase-6-documentation-and-handoff\"\u003ePhase 6: Documentation and handoff\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eClaude automatically generated comprehensive attack documentation throughout all campaign phases. Structured markdown files tracked discovered services, harvested credentials, extracted data, exploitation techniques, and complete attack progression. This documentation enabled seamless handoff between operators, facilitated campaign resumption after interruptions, and supported strategic decision-making about follow-on activities.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhy claude? Why not any other model (gpt-5? gemini? why not just open source models…) I’m just wondering why this group would pick the company that cares about safety the most?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe operational infrastructure relied overwhelmingly on open source penetration testing tools rather than custom malware development. 
Standard security utilities including network scanners, database exploitation frameworks, password crackers, and binary analysis suites comprised the core technical toolkit. These commodity tools were orchestrated through custom automation frameworks built around Model Context Protocol servers, enabling the framework’s AI agents to execute remote commands, coordinate multiple tools simultaneously, and maintain persistent operational state.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eNice, so the users were experts in their field: they built their own MCP connectors for these tools and had at least tested them before actually using them in the attack\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThis raises an important question: if AI models can be misused for cyberattacks at this scale, why continue to develop and release them? The answer is that the very abilities that allow Claude to be used in these attacks also make it crucial for cyber defense. When sophisticated cyberattacks attacks inevitably occur, our goal is for Claude—into which we’ve built strong safeguards—to assist cybersecurity professionals to detect, disrupt, and prepare for future versions of the attack. Indeed, our Threat Intelligence team used Claude extensively in analyzing the enormous amounts of data generated during this very investigation.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eMakes obvious sense\u003c/p\u003e\n\n\u003ch4 id=\"final-thoughts-2\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eThis post lacks any detail about the attack itself (I’d argue it isn’t a paper or a report that’s well suited for the security teams, more like an AI model building report that’s common these days). However, it describes how the attack was carried out, much like how most people would use these tools, and it’s nice to see how experts are using claude and other tools to automate their workflows. 
It is quite interesting that the actors used claude for the attack. Maybe they found it to be the most effective / well-developed tool for such tasks? There are some lessons here for automating my own tasks!\u003c/p\u003e\n\n\u003ch3 id=\"the-community-response\"\u003eThe community response\u003c/h3\u003e\n\n\u003cp\u003eThe security community takes no bullshit from what I know. So, an expert in the field posted this as a response: https://djnn.sh/posts/anthropic-s-paper-smells-like-bullshit/ and it got a lot of feedback on Hacker News: https://news.ycombinator.com/item?id=45944296\u003c/p\u003e\n\n\u003cp\u003eLet me read through this and see what an expert thinks and if I would agree (having been in the field a bit)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you’re like me, you then eagerly read the rest of the paper, hoping to find clues and technical details on the TTPs (Tactics, Techniques and Procedures), or IoCs (Indicators of Compromise) to advance the research. However, the report very quickly falls flat, which sucks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWow, immediate attack on the paper/report.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThis is typically done by sharing domain-names linked with the campaign, MD5 or SHA512 hashes you could look for on Virus Exchange websites such as VirusTotal, or other markers that would help you verify that your networks are safe. As an example, here is the French CERT sharing (in French, but an English version is available too) about APT28’s TTPs.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eVery much true. For any existing security vulnerability, it’s common in the field to publish a detailed account of what the attack did. Maybe an expert was not allowed to write in the format they wanted to, or maybe it wasn’t an expert who wrote the report.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat kind of tooling is used ? 
What kind of information has been extracted ? Who is at risk ? How does a CERT identifies an AI agent in their networks ? None of these questions are answered. It’s not like Anthropic doesn’t have access to this data, since they claim they were able to stop it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe author dug deeper than I did. Great to see and I should have done the same.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eHow ? Did it run Mimikatz ? Did it access Cloud environments ? We don’t even know what kind of systems were affected. There is no details, or fact-based evidence to support these claims or even help other people protect their networks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe author goes on a rant. Nice to see passion :)\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eLook, is it very likely that Threat Actors are using these Agents with bad intentions, no one is disputing that. But this report does not meet the standard of publishing for serious companies. The same goes with research in other fields. You cannot just claim things and not back it up in any way, and we cannot as an industry accept that it’s OK for companies to release this. There seem to be a pattern for Tech Companies (especially in AI, but they’re not the only culprits) out there to just announce things, generate hype and then under-deliever. Just because it works with VCs doesn’t mean it should work with us. We should, as an industry, expect better.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTrue and false (feel free to disagree). I agree that this is the standard, BUT the company is not a security company. I would say they should not have sold it as a report/paper, but rather kept it as a blog post if they didn’t want to release details…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf they’re going to release IoCs and proof of everything, I’d be happy to share them here. 
But until them, I will say this: this paper would not pass any review board. It’s irresponsible at best to accuse other countries of serious things without backing it up. Yes, I am aware that Chinese-linked APTs are out there and very aggressive, and Yes, I am aware that Threat Actors misuse LLMs all the time, but that is besides the point. We need fact-based evidence. We need to be able to verify all this. Otherwise, anyone can say anything, on the premise that it’s probably happening. But that’s not good enough.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI like the passion. I disagree, since it’s the open internet and they haven’t submitted it to anyone for review (that I know of?). I DO agree that the internet should not accept bullshit (I don’t agree that the report is bullshit) and that it’s fine to express your opinions online.\u003c/p\u003e","summary":"AI 2027 and related works","date_published":"2025-07-19T00:00:00+00:00","date_modified":"2025-07-19T00:00:00+00:00","author":{"name":""},"tags":["Rambling"]},{"id":"https://maknee.github.io/blog/2025/Always-Measure-One-Level-Deeper","url":"https://maknee.github.io/blog/2025/Always-Measure-One-Level-Deeper/","title":"Always Measure One Level Deeper","content_html":"\u003ch3 id=\"always-measure-one-level-deeper\"\u003eAlways Measure One Level Deeper\u003c/h3\u003e\n\n\u003cp\u003eThoughts about \u003ca href=\"https://cacm.acm.org/research/always-measure-one-level-deeper/\"\u003eAlways Measure One Level Deeper\u003c/a\u003e by John Ousterhout.\u003c/p\u003e\n\n\u003cp\u003eBefore we dive in, note that this was written in 2018, when John was not retired yet (I think)\u003c/p\u003e\n\n\u003ch4 id=\"thoughts-along-the-way\"\u003eThoughts along the way\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePerformance measurement is one of the most important parts of software development. 
In academic research a thorough performance evaluation is considered essential for many publications to prove the value of a new idea. In industry, performance evaluation is necessary to maintain a high level of performance across the lifetime of a product.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTo the point and not immediately obvious\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAs a result, performance measurement is often done poorly, even by experienced developers. For example, if you have written a conference paper on a software system, it probably unfolded like this: The system implementation took longer than expected, so performance evaluation could not begin until a week or two before the paper submission deadline. The first attempts to run benchmarks resulted in system crashes, so you spent the next week fixing bugs. At this point the benchmarks ran, but the system’s performance was not much better than the comparison systems. You tried different experiments, hoping to find one where the system looked good; this exposed yet more bugs that had to be fixed. Time was running out, so you stopped measuring as soon as you found an experiment that produced positive results. The paper focused on this experiment, omitting the results that were less favorable.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eEvery single paper is like this\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMistake 1: Trusting the numbers. Engineers are easily fooled during performance measurements because measurement bugs are not obvious. Engineers are used to dealing with functional bugs, which tend to be noticeable because they cause the system to crash or misbehave. If the system produces the desired behavior, it is probably working. 
Engineers tend to apply the same philosophy to performance measurements; if performance numbers are being generated and the system is not crashing, they assume the numbers are correct.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWow, good insight\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eI designed our first log-structured file system, we were fairly certain that reference patterns exhibiting locality would result in better performance than those without locality. Fortunately, we decided to measure, to be sure. To our surprise, the workloads with locality behaved worse than those without. It took considerable analysis to understand this behavior. The reasons were subtle, but they exposed important properties of the system and led us to a new policy for garbage collection that improved the system’s performance significantly. If we had trusted our initial guess, we would have missed an important opportunity for performance improvement.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eCan’t make assumptions!\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIt is unsafe to base conclusions on intuition alone, yet engineers do it all the time. A common mistake is for an engineer to hypothesize that a particular data structure is too slow and then replace it with a new data structure the engineer believes will be faster. If the problem is not verified by measuring performance, there is a good chance the optimization will not improve performance. 
The code change will simply waste a lot of time and probably introduce unnecessary complexity.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI do this all the time – need to measure with and without\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhen I find a guess presented as fact and ask for justification, I sometimes get this response: “What else could it possibly be?” But this is a cop-out, suggesting it is up to others to prove the theory wrong and OK to make unsubstantiated claims until someone else proves them false.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSame with this\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMost performance measurements I see are superficial, measuring only the outermost visible behavior of a system (such as the overall running time of an application or the average latency of requests made to a server). These measurements are essential, as they represent the bottom line by which a system is likely to be judged, but they are not sufficient. They leave many questions unanswered (such as “What are the limits that keep the system from performing better?” and “Which of the improvements had the greatest impact on performance?”). In order to get a deep understanding of system performance, the internal behavior of a system must be measured, in addition to its top-level performance.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWow, yes this takes time\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eConfirmation bias causes people to select and interpret data in a way that supports their hypotheses. For example, confirmation bias affects your level of trust. When you see a result that supports your hypothesis, you are more likely to accept the result without question.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eConfirmation bias also affects how you present information. 
You are more likely to include results that support your hypothesis and downplay or omit results that are negative. For example, I frequently see claims in papers of the form: “XXX is up to 3.5x faster than YYY.” Such claims cherry-pick the best result to report and are misleading because they do not indicate what performance can be expected in the common case. Statements like this belong in late-night TV commercials, not scientific papers.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eBias, need to present well\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePerformance analysis is not an instantaneous process like taking a picture of a finished artwork. It is a long and drawn-out process of confusion, discovery, and improvement. Performance analysis goes through several phases, each of which can take anywhere from a few days to a few weeks. First, you must add instrumentation code to the system to record the desired metrics. You must then get benchmark applications running, either by writing them or by downloading and installing existing programs. Running benchmarks will probably stress the system enough to expose bugs, and you will need to then track down and fix them. Eventually, the system will run well enough to start producing performance numbers. However, these numbers will almost certainly be wrong. The next step is to find and fix bugs in the measurements. Once you have verified the accuracy of the measurements, you will start to uncover problems with the system itself. As you look over the performance measurements, you will probably uncover additional functional bugs. Once they have been fixed, you can start analyzing the performance in depth. You will almost certainly discover opportunities to improve performance, and it is important to have enough time to make these improvements. You will encounter many things that do not make sense; in order to resolve them, you will need to add new metrics and validate them. 
To get the best results, you must iterate several times improving the metrics, measuring performance, and improving the system.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat an example. Iterate iterate iterate\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eI often challenge them by asking: “Suppose I said I don’t believe these measurements. What can you say to convince me that they are correct?”\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI should ask myself this\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAs you begin collecting measurements, compare them and be alert for inconsistencies. There will almost always be things that do not make sense. When something does not make complete sense, stop and gather more data. For example, in a recent measurement of a new network transport protocol, a benchmark indicated that a server could handle no more than 600,000 packets per second. However, my colleagues and I had seen servers process more than 900,000 packets per second with other protocols and believed the new protocol was at least as efficient as the old ones. We decided to gather additional data. As a result, we discovered a bug in the flow-control mechanism on the client side: clients were not transmitting data fast enough to keep the server fully loaded. Fixing the bug improved performance to the level we expected.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting: gather more data, but how do you know what to do next and which data to filter? I guess that comes from experience\u003c/p\u003e\n\n\u003ch5 id=\"keys-to-high-quality-performance-analysis\"\u003eKeys to High-Quality Performance Analysis\u003c/h5\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe first step toward high-quality performance measurements is to allow enough time. 
If you are measuring a non-trivial system, you should plan on at least two to three months.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThat’s interesting: it makes sense, but it takes a loooong time\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePerformance analysis is not an instantaneous process like taking a picture of a finished artwork. It is a long and drawn-out process of confusion, discovery, and improvement. Performance analysis goes through several phases, each of which can take anywhere from a few days to a few weeks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTake different measurements at the same level. For example, if you are measuring file-system throughput, do not measure just the throughput seen by a user application; also measure the throughput observed inside the operating system (such as at the file block cache). These measurements should match;\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMeasure the system’s behavior at a lower level to break down the factors that determine performance, as I discuss later under Rule 4 (Always measure one level deeper);\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMake back-of-the-envelope calculations to see if the measurements are in the ballpark expected; and\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eRun simulations and compare their results to measurements of the real implementation.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eDamn, that’s a lot of different steps. Essentially: always double-check\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAbove all, do not tolerate anything you do not understand.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat a thought.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAbove all, do not tolerate anything you do not understand. 
Assume there are bugs and problems with every measurement, and your job is to find and fix them. If you do not find problems, you should feel uneasy, because there are probably bugs you missed.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe best way to use intuition is to identify promising areas for further exploration. For example, when looking over performance measurements, ask yourself if they make sense. How does the performance compare to what you expected? Does it seem too good to be true? Does the system scale more poorly than you had hoped? Does a curve jump unexpectedly when you expected it to be smooth? Do some benchmarks exhibit behavior that is dramatically different from others? Consider anything that does not match your intuition a red flag and investigate it, as described in Rule 2 (Never trust a number generated by a computer). Intuition can be very helpful in identifying problems.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you continually form intuitions and then test them you will gain knowledge that helps you form better intuition in the future. Every false intuition means there was something you did not fully understand; in the process of testing it and discovering why it is false, you will learn something useful.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eIntuition is used as a guide for the first step\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf you are measuring overall latency for remote procedure calls, you could measure deeper by breaking down that latency, determining how much time is spent in the client machine, how much time is spent in the network, and how much time is spent on the server. You could also measure where time is spent on the client and server. If you are measuring the overall throughput of a system, the system probably consists of a pipeline containing several components. 
Measure the utilization of each component (the fraction of time that component is busy). At least one component should be 100% utilized; if not, it should be possible to achieve a higher throughput.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eLatency and throughput measurements in a single sentence?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIn recent measurements of a new network transport, one of my students found that round-trip tail latency was higher than our simulations had predicted. The student measured software latency in detail on both the sending and the receiving machines but found nothing that could account for the high tail latency. At this point we were about to conclude that the delays must be caused by the network switch. What else could it be? This would have been Mistake 2 (Guessing instead of measuring). Before giving up, we decided to dig deeper and measure precise timings for each individual packet. The measurements surprised us, showing that outlier delays were not isolated events. Delay tended to build up over a series of packets, affecting all of the packets from a single sender over a relatively long time interval, including packets for different destinations. This was a crucial clue. After several additional measurements, the student discovered that long queues were building up in the sender’s network interface due to a software bug. The transport included code to estimate the queue length and prevent queue buildup, but there was a bug in the estimator caused by underflow of an unsigned integer. The underflow was easy to fix, at which point tail latency dropped dramatically. Not only did this process improve the system’s performance, it taught us an important lesson about the risks of unsigned integers.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGood example\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnother way to measure deeper is to consider more detail. 
Instead of just looking at average values, graph the entire distribution and noodle over the shape to see if it provides useful information. Then look at some of the raw data samples to see if there are patterns. In one measurement of RPC latency, a student found that the average latency was higher than we expected. The latency was not intolerably high, and it would have been easy to simply accept this level of performance. Fortunately, the student decided to graph the times for individual RPCs. It turned out the data was bimodal, whereby every other RPC completed quickly, but the intervening ones were all significantly slower. With this information, the student tracked down and fixed a configuration error that eliminated all of the slow times. In this case, the average value was not a good indicator of system behavior.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSo basically: always look at the full distribution and the individual samples, not just the average, and keep measuring\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eDo not spend a lot of time agonizing over which deeper measurements to make. If the top-level measurements contain contradictions or things that are surprising, start with measurements that could help resolve them. Or pick measurements that will identify performance bottlenecks. If nothing else, choose a few metrics that are most obvious and easiest to collect, even if you are not sure they will be particularly illuminating. Once you look at the results, you will almost certainly find things that do not make sense; from this point on, track down and resolve everything that does not make perfect sense. Along the way you will discover other surprises; track them down as well. 
Over time, you will develop intuition about what kinds of deeper measurements are most likely to be fruitful.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI see, just go for it, use standard tools\u003c/p\u003e\n\n\u003ch5 id=\"measurement-infrastructure\"\u003eMeasurement Infrastructure\u003c/h5\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eMaking good performance measurements takes time, so it is worth creating infrastructure to help you work more efficiently. The infrastructure will easily pay for itself by the time the measurement project is finished. Furthermore, performance measurements tend to be run repeatedly, making infrastructure even more valuable. In a cloud service provider, for example, measurements must be made continuously in order to maintain contractual service levels. In a research project, the full suite of performance measurements will be run several times (such as before submission, after the paper is accepted, and again during the writing of a Ph.D. dissertation). 
It is important to have infrastructure that makes it easy to rerun tests.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes I see… this is how you learn how to build such infrastructure\u003c/p\u003e\n\n\u003ch4 id=\"summaryimportant-takeaways\"\u003eSummary/Important takeaways\u003c/h4\u003e\n\n\u003cul\u003e\n \u003cli\u003eDig deep into understanding performance\n \u003cul\u003e\n \u003cli\u003eThe question is how to do so (are you measuring the right thing, and how do you identify when you fucked up)\u003c/li\u003e\n \u003cli\u003eThis is a trained methodology (a way of thinking about measuring performance), which takes discipline to follow\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eMistakes to watch out for\n \u003cul\u003e\n \u003cli\u003eTrusting numbers immediately if the system is not crashing\n \u003cul\u003e\n \u003cli\u003eperformance bugs occur in non-crashing conditions, so they are not immediately obvious\u003c/li\u003e\n \u003cli\u003eso the logical question is how do you prove that the numbers are trustworthy?\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eGuessing (or making seemingly obvious assumptions) without backing up the claims\n \u003cul\u003e\n \u003cli\u003eex, if you claim the system is bottlenecked by I/O, you need to show that it’s true with numbers; maybe it actually isn’t bottlenecked by I/O (very true)\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eOnly measuring end-2-end\n \u003cul\u003e\n \u003cli\u003eWhat would make it better? 
What’s taking the longest in the system?\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eIf you believe in the idea, you expect the performance to be good (confirmation bias) and don’t double-check the numbers\u003c/li\u003e\n \u003cli\u003eDon’t rush your measurements - it’s easy to make mistakes\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eHow to not make mistakes\n \u003cul\u003e\n \u003cli\u003eTime\n \u003cul\u003e\n \u003cli\u003eNeed to build instrumentation, benchmarks, patch bugs, repeat\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eFind different ways to measure the same thing/Don’t trust the number\n \u003cul\u003e\n \u003cli\u003e“I often challenge them by asking: ‘Suppose I said I don’t believe these measurements. What can you say to convince me that they are correct?’”\u003c/li\u003e\n \u003cli\u003eFor example, if you are measuring file-system throughput, do not measure just the throughput seen by a user application; also measure the throughput observed inside the operating system (such as at the file block cache). 
These measurements should match\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eUse your intuition to ask questions, not to answer them\n \u003cul\u003e\n \u003cli\u003eIt’s good to have a gut feeling to check something, but always verify that it’s true\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eAlways measure one level deeper to break down numbers\n \u003cul\u003e\n \u003cli\u003eex, for e2e latency, break it down into client, server, and network time\u003c/li\u003e\n \u003cli\u003evalidate top level numbers\u003c/li\u003e\n \u003cli\u003euse your knowledge of known tools\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eMeasurement Infrastructure\n \u003cul\u003e\n \u003cli\u003eHow to build your set of tools to measure performance\u003c/li\u003e\n \u003cli\u003eWhat is good infrastructure\n \u003cul\u003e\n \u003cli\u003eAutomated: performance is measured on every run\u003c/li\u003e\n \u003cli\u003eEasy to digest/understand\u003c/li\u003e\n \u003cli\u003ebenchmarks to compare\u003c/li\u003e\n \u003cli\u003eDashboard\n \u003cul\u003e\n \u003cli\u003egoal: easy to understand!\u003c/li\u003e\n \u003cli\u003ebut brings together a lot of data\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-07-19/dashboard.png\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eDashboard example\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cul\u003e\n \u003cli\u003eGives a lot of information, breaking each metric down into e2e, network, and internal software\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-07-19/figure2.jpg\" width=\"100%\" alt=\"\" 
/\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eDashboard example\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cul\u003e\n \u003cli\u003eExample of how to expand and get a better understanding – it depends on the inputs\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/ramblings/2025-07-19/figure3.jpg\" width=\"100%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eDashboard example\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cul\u003e\n \u003cli\u003eExample of how to expand and get a better understanding – it depends on the inputs (this time, you have to split the x into equal parts)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch4 id=\"final-thoughts\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eThis is a very good read. Performance is something that you iterate on. It’s quite a process that’s simple on the surface: make assumptions, create benchmarks to verify that claim. 
But the reality is different:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003eMake infrastructure to benchmark\u003c/li\u003e\n \u003cli\u003ePerformance process\n \u003cul\u003e\n \u003cli\u003ethink of which important variables to observe from the system (mostly throughput/latency)\u003c/li\u003e\n \u003cli\u003eback up with benchmark\n \u003cul\u003e\n \u003cli\u003ethe initial numbers: end-to-end numbers (process one request)\u003c/li\u003e\n \u003cli\u003ethe subnumbers (network/storage/processing)\u003c/li\u003e\n \u003cli\u003ecompare against other sources to check whether the values are in an appropriate range\u003c/li\u003e\n \u003cli\u003erepeat\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n\u003c/ul\u003e","summary":"Always Measure One Level Deeper","date_published":"2025-07-19T00:00:00+00:00","date_modified":"2025-07-19T00:00:00+00:00","author":{"name":""},"tags":["Rambling"]},{"id":"https://maknee.github.io/blog/2025/Paul-Graham-Why-Nerds-Are-Unpopular","url":"https://maknee.github.io/blog/2025/Paul-Graham-Why-Nerds-Are-Unpopular/","title":"Paul Graham - Why Nerds Are Unpopular","content_html":"\u003ch3 id=\"paul-graham---why-nerds-are-unpopular\"\u003ePaul Graham - Why Nerds Are Unpopular\u003c/h3\u003e\n\n\u003cp\u003eThoughts about \u003ca href=\"https://paulgraham.com/nerds.html\"\u003eWhy Nerds Are Unpopular\u003c/a\u003e by Paul Graham.\u003c/p\u003e\n\n\u003cp\u003eBefore we dive in: this was written in 2003, when Paul Graham was 38 and not yet married, with no kids.\u003c/p\u003e\n\n\u003cp\u003eThis is also a rather long essay…\u003c/p\u003e\n\n\u003ch4 id=\"thoughts-along-the-way\"\u003eThoughts along the way\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWe sat at a D table, as low as you could get without looking physically different. We were not being especially candid to grade ourselves as D. It would have taken a deliberate lie to say otherwise. 
Everyone in the school knew exactly how popular everyone else was, including us.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSeems relatable.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eI know a lot of people who were nerds in school, and they all tell the same story: there is a strong correlation between being smart and being a nerd, and an even stronger inverse correlation between being a nerd and being popular. Being smart seems to make you unpopular.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting – the time investment goes into getting good grades rather than into appearance/people\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhy? To someone in school now, that may seem an odd question to ask. The mere fact is so overwhelming that it may seem strange to imagine that it could be any other way. But it could. Being smart doesn’t make you an outcast in elementary school. Nor does it harm you in the real world. Nor, as far as I can tell, is the problem so bad in most other countries. But in a typical American secondary school, being smart is likely to make your life difficult. Why?\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting… observation. Still true as you get older, not just in school\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIn the schools I went to, being smart just didn’t matter much. Kids didn’t admire it or despise it. All other things being equal, they would have preferred to be on the smart side of average rather than the dumb side, but intelligence counted far less than, say, physical appearance, charisma, or athletic ability.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes. Intelligence does not matter much to kids, as it’s harder to read.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSo if intelligence in itself is not a factor in popularity, why are smart kids so consistently unpopular? 
The answer, I think, is that they don’t really want to be popular.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting, huh, people need attention in some way.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBut in fact I didn’t, not enough. There was something else I wanted more: to be smart. Not simply to do well in school, though that counted for something, but to design beautiful rockets, or to write well, or to understand how to program computers. In general, to make great things.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI guess so, but generally people want attention in some way, not so much to be popular…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAt the time I never tried to separate my wants and weigh them against one another. If I had, I would have seen that being smart was more important. If someone had offered me the chance to be the most popular kid in school, but only at the price of being of average intelligence (humor me here), I wouldn’t have taken it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI agree, but only slightly. Being popular and knowing how to use it can benefit you as much as (sometimes more than) being smart.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnd that, I think, is the root of the problem. Nerds serve two masters. They want to be popular, certainly, but they want even more to be smart. And popularity is not something you can do in your spare time, not in the fiercely competitive environment of an American secondary school.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHaha, yeah – time investment.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eNerds don’t realize this. They don’t realize that it takes work to be popular. In general, people outside some very demanding field don’t realize the extent to which success depends on constant (though often unconscious) effort. 
For example, most people seem to consider the ability to draw as some kind of innate quality, like being tall. In fact, most people who “can draw” like drawing, and have spent many hours doing it; that’s why they’re good at it. Likewise, popular isn’t just something you are or you aren’t, but something you make yourself.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eAgreed. I did not realize this until very late. It takes a lot of time and thought and honestly, experimentation (+ failures) to become popular…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eEven if nerds cared as much as other kids about popularity, being popular would be more work for them. The popular kids learned to be popular, and to want to be popular, the same way the nerds learned to be smart, and to want to be smart: from their parents. While the nerds were being trained to get the right answers, the popular kids were being trained to please.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHaha, surprised I reached the same reasoning. Paul’s writing is good.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSo far I’ve been finessing the relationship between smart and nerd, using them as if they were interchangeable. In fact it’s only the context that makes them so. A nerd is someone who isn’t socially adept enough. But “enough” depends on where you are. In a typical American school, standards for coolness are so high (or at least, so specific) that you don’t have to be especially awkward to look awkward by comparison.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh god. Yes. It’s so very easy to seem awkward to someone, even as you get older. People tend to judge quickly, especially in the US.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePartly because teenagers are still half children, and many children are just intrinsically cruel. Some torture nerds for the same reason they pull the legs off spiders. 
Before you develop a conscience, torture is amusing.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHaha… yes, people don’t accept differences (from their own view of the world), especially if they’re children\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnother reason kids persecute nerds is to make themselves feel better. When you tread water, you lift yourself up by pushing water down. Likewise, in any social hierarchy, people unsure of their own position will try to emphasize it by maltreating those they think rank below. I’ve read that this is why poor whites in the United States are the group most hostile to blacks.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes… Definitely when I was a teenager. I see this to some extent, even now.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBecause they’re at the bottom of the scale, nerds are a safe target for the entire school. If I remember correctly, the most popular kids don’t persecute nerds; they don’t need to stoop to such things. Most of the persecution comes from kids lower down, the nervous middle classes.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh interesting – good observation. Happens when you’re older too, or maybe I just interpret some events like that.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAs well as gaining points by distancing oneself from unpopular kids, one loses points by being close to them. A woman I know says that in high school she liked nerds, but was afraid to be seen talking to them because the other girls would make fun of her. Unpopularity is a communicable disease; kids too nice to pick on nerds will still ostracize them in self-defense.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHaha…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIt’s important to realize that, no, the adults don’t know what the kids are doing to one another. 
They know, in the abstract, that kids are monstrously cruel to one another, just as we know in the abstract that people get tortured in poorer countries. But, like us, they don’t like to dwell on this depressing fact, and they don’t see evidence of specific abuses unless they go looking for it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI don’t think I understand it to that extent. Maybe I’ve forgotten.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003ePublic school teachers are in much the same position as prison wardens. Wardens’ main concern is to keep the prisoners on the premises. They also need to keep them fed, and as far as possible prevent them from killing one another. Beyond that, they want to have as little to do with the prisoners as possible, so they leave them to create whatever social organization they want. From what I’ve read, the society that the prisoners create is warped, savage, and pervasive, and it is no fun to be at the bottom of it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWow, what a conclusion. I do agree with this. Again, this is about public, not PRIVATE, school teachers – public school teachers have something like 30-40 students per class, so there isn’t much time to devote to each kid’s problems.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhen the things you do have real effects, it’s no longer enough just to be pleasing. It starts to be important to get the right answers, and that’s where nerds show to advantage. Bill Gates will of course come to mind. Though notoriously lacking in social skills, he gets the right answers, at least as measured in revenue.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHuh, yes. School is much more restricted in that sense.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIf I could go back and give my thirteen year old self some advice, the main thing I’d tell him would be to stick his head up and look around. 
I didn’t really grasp it at the time, but the whole world we lived in was as fake as a Twinkie. Not just school, but the entire town. Why do people move to suburbia? To have kids! So no wonder it seemed boring and sterile. The whole place was a giant nursery, an artificial town created explicitly for the purpose of breeding children.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGood advice – I’m going to take this advice.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat bothers me is not that the kids are kept in prisons, but that (a) they aren’t told about it, and (b) the prisons are run mostly by the inmates. Kids are sent off to spend six years memorizing meaningless facts in a world ruled by a caste of giants who run after an oblong brown ball, as if this were the most natural thing in the world. And if they balk at this surreal cocktail, they’re called misfits.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGlad I reached this conclusion when I was in school.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAdults can’t avoid seeing that teenage kids are tormented. So why don’t they do something about it? Because they blame it on puberty. The reason kids are so unhappy, adults tell themselves, is that monstrous new chemicals, hormones, are now coursing through their bloodstream and messing up everything. There’s nothing wrong with the system; it’s just inevitable that kids will be miserable at that age.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eBlaming on something that can’t be fully explained – Typical. Also, sometimes I fall into this habit, but I’ve stopped it mostly.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhen I was in school, suicide was a constant topic among the smarter kids. No one I knew did it, but several planned to, and some may have tried. Mostly this was just a pose. Like other teenagers, we loved the dramatic, and suicide seemed very dramatic. 
But partly it was because our lives were at times genuinely miserable.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTrue true true\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAt best it was practice for real work we might do far in the future, so far that we didn’t even know at the time what we were practicing for. More often it was just an arbitrary series of hoops to jump through, words without content designed mainly for testability. (The three main causes of the Civil War were…. Test: List the three main causes of the Civil War.)\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnd there was no way to opt out. The adults had agreed among themselves that this was to be the route to college. The only way to escape this empty life was to submit to it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eEven in adult life, with a “job”, you get these structured instances too…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTeenage kids used to have a more active role in society. In pre-industrial times, they were all apprentices of one sort or another, whether in shops or on farms or even on warships. They weren’t left to create their own societies. They were junior members of adult societies.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThat’s a good observation – most of the useful stuff I learned was outside of school – working with my father, exploring/navigating the city\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat happened? We’re up against a hard one here. The cause of this problem is the same as the cause of so many present ills: specialization. As jobs become more specialized, we have to train longer for them. Kids in pre-industrial times started working at about 14 at the latest; kids on farms, where most people lived, began far earlier. Now kids who go to college don’t start working full-time till 21 or 22. 
With some degrees, like MDs and PhDs, you may not finish your training till 30.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting thought. Yes, and it REQUIRES schooling again…\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe real problem is the emptiness of school life. We won’t see solutions till adults realize that. The adults who may realize it first are the ones who were themselves nerds in school. Do you want your kids to be as unhappy in eighth grade as you were? I wouldn’t. Well, then, is there anything we can do to fix things? Almost certainly. There is nothing inevitable about the current system. It has come about mostly by default.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eYes. Man, I was dumb for not realizing this sooner…\u003c/p\u003e\n\n\u003ch4 id=\"final-thoughts\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eThis is one of Paul’s older essays. He rambles quite a bit. Every paragraph after roughly the 5th one repeats what he has already said, but with a different story or tone. I like the point of the essay. Nerds are UNPOPULAR, and that unpopularity actually drags on these days (even past college) due to the internet and these traits being embedded in culture beyond school.\u003c/p\u003e\n\n\u003cp\u003eOne place I disagree with Paul: popularity does matter, just in a different sense. Being popular is something most people are not well adjusted to, say being attractive for the first time, or becoming better known on the internet and being able to respond well in a social setting. I believe that these early years in life build that and allow one to experience that type of feeling – to build “confidence” in some way. This matters after the teenage years, and it is a useful skill to have. However, becoming popular is hard, and most kids are just thrown into the battlegrounds to figure it out. 
No one really teaches them.\u003c/p\u003e\n\n\u003cp\u003eI do agree with most of Paul’s points on school. It’s a rigid structure that is basically a battleground for kids to bully one another and sort themselves into groups. Then, for the most part, you can get by in classes if you put in some effort and learn how to do so (I guess this is what is “smart”?). I wish kids had more apprentice-style classes or similar, so that someone could show them some view of the adult world. I didn’t understand some of it until after college and am still learning.\u003c/p\u003e\n\n\u003cp\u003eBut why does Paul seem so harsh – angry almost? Does he regret going to such schools? Bitter? I can relate if so. I can’t really describe good things about school. Just hung out with the nerds, and that was fun, I think?\u003c/p\u003e","summary":"Paul Graham - Why Nerds Are Unpopular","date_published":"2025-06-29T00:00:00+00:00","date_modified":"2025-06-29T00:00:00+00:00","author":{"name":""},"tags":["Rambling"]},{"id":"https://maknee.github.io/blog/2025/3FS-Performance-Journal-2","url":"https://maknee.github.io/blog/2025/3FS-Performance-Journal-2/","title":"A Reality Check on DeepSeek’s Distributed File System Benchmarks","content_html":"\u003ch1 id=\"series\"\u003eSeries\u003c/h1\u003e\n\n\u003cul\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-1/\"\u003eAn Intro to DeepSeek’s Distributed File System\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-2/\"\u003eA Reality Check on DeepSeek’s Distributed File System Benchmarks\u003c/a\u003e\u003c/li\u003e\n \u003cli\u003e\u003ca href=\"/blog/2025/3FS-Performance-Journal-3/\"\u003eNetwork Storage and Scaling Characteristics of a Distributed Filesystem\u003c/a\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003c!--\n- [Theoretical Performance Limits of 3FS](/blog/2018/RTX-DXR-Path-Tracer-Host/)\n- [Benchmarking 3FS](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n- [Analysis of 3FS 
Benchmarks](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n- [Improving 3FS Performance](/blog/2018/RTX-DXR-Path-Tracer-HLSL/)\n--\u003e\n\n\u003ch1 id=\"how-should-we-analyze-3fs\"\u003eHow should we analyze 3FS?\u003c/h1\u003e\n\n\u003cp\u003eIn \u003ca href=\"/blog/2025/3FS-Performance-Journal-1/\"\u003emy previous post\u003c/a\u003e, I introduced DeepSeek’s \u003ca href=\"https://github.com/deepseek-ai/3FS/tree/ee9a5cee0a85c64f4797bf380257350ca1becd36\"\u003e3FS distributed file system\u003c/a\u003e – exploring its architecture, components, and the CRAQ protocol that provides its consistency guarantees. Now, I want to take a closer look at the published benchmark results and performance claims.\u003c/p\u003e\n\n\u003cp\u003eWhen evaluating distributed systems, there’s a tendency to jump straight into complex profiling tools and detailed metrics.\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eTrying out perf, strace for syscalls, and iostat for disk is essentially throwing random darts until you hit something\u003c/span\u003e However, I find tremendous value in performing an initial “performance reality check” on numbers and graphs. The check uses reference numbers from other sources, such as hardware manufacturer specifications or existing benchmarks, to quickly provide a baseline for a particular system\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eFor example, when I drive a car on the highway, I try to match the speed to the other cars around me. Without that reference, it might turn out that I’m over the speed limit if I’m not constantly checking the speed gauge\u003c/span\u003e. 
This approach helps identify potential bottlenecks or inconsistencies before deploying software tools for deeper investigation.\u003c/p\u003e\n\n\u003cp\u003eA “performance reality check” can reveal the following:\u003c/p\u003e\n\n\u003col\u003e\n \u003cli\u003eIt validates whether benchmark results match what we’d expect based on the hardware configuration\u003c/li\u003e\n \u003cli\u003eIt helps identify which component (network, storage, CPU, etc.) represents the main bottleneck\u003c/li\u003e\n \u003cli\u003eIt reveals the percentage of theoretical capacity actually being utilized\u003c/li\u003e\n \u003cli\u003eIt verifies whether the authors’ claims are valid and how they may have arrived at those conclusions\u003c/li\u003e\n\u003c/ol\u003e\n\n\u003cp\u003eTo illustrate this method, imagine a startup making claims about their new database – “built for AI training” and “hyper performance”. They showcase a benchmark from a single node:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/example1.svg\" width=\"75%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eA company produces a graph showing the throughput of one of their machines running the workload\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe system managed to read 250 GB over the test period, which seems impressive! However, this is like saying I drove 100 miles without mentioning whether it took one hour or ten. The rate (GB per second) reveals the actual work accomplished relative to time invested. 
Let’s approximate it by drawing a line through the data and measuring its slope.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/example2.svg\" width=\"75%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eDrawing a line through the graph to get the rate\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003e2 GB/s. Great number, but one might wonder – what should we compare this number to?\u003c/p\u003e\n\n\u003cp\u003eA good start is to ask: is this utilizing the full potential of the hardware? Looking up \u003ca href=\"https://www.micron.com/content/dam/micron/global/public/documents/products/technical-marketing-brief/7450-nvme-ssd-tech-prod-spec.pdf\"\u003emodern SSD\u003c/a\u003e specifications for random reads and plotting that on the graph can reveal the following:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/example3.svg\" width=\"75%\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eTaking a different look at the graph with theoretical limits\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eTheoretically, the system should reach 500 GB by the end of the test period, which over the same duration works out to roughly 4 GB/s, twice the measured 2 GB/s!\u003c/p\u003e\n\n\u003cp\u003eThe benchmark is only utilizing about half of the available device bandwidth. This might raise some eyebrows about their performance claims – where are the bottlenecks?\u003c/p\u003e\n\n\u003cp\u003eThis analytical approach is exactly what I’ll apply to DeepSeek’s 3FS benchmarks throughout this post. 
By calculating what the hardware should deliver and comparing it to what 3FS actually achieves, we can identify where the possible bottlenecks lie and assess whether performance claims hold up.\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eWhile not exact, these comparisons give us immediate insights that would take days to obtain through benchmarking\u003c/span\u003e\u003c/p\u003e\n\n\u003ch2 id=\"into-analyzing-3fs\"\u003eInto analyzing 3FS\u003c/h2\u003e\n\n\u003cp\u003eI’ll be examining three different workloads that showcase 3FS in action:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAI training jobs featuring a massive amount of reads\u003c/li\u003e\n \u003cli\u003eGraySort, a synthetic sorting benchmark with a mix of reads and writes\u003c/li\u003e\n \u003cli\u003eKV cache in operation, representing LLM inference workloads with random reads\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eEach benchmark consists of two main components – client and storage. The client sends a request to read/write to the storage node over a network. Then, the storage node reads/writes the data to its device(s) and responds to the client by sending a message back. Thus, the two main hardware components one should analyze closely are network and storage.\u003c/p\u003e\n\n\u003cp\u003eFor each benchmark, I’ll break down the hardware configuration, calculate theoretical maximums, and analyze how close the system comes to achieving its potential performance. 
Through this analysis, we’ll develop intuition about 3FS’s real-world capabilities before even deploying it.\u003c/p\u003e\n\n\u003cp\u003eLet’s start by examining what could be 3FS’s most important benchmark: training throughput for AI workloads.\u003c/p\u003e\n\n\u003ch2 id=\"first-workload-training-job\"\u003eFirst workload: Training job\u003c/h2\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/peak_throughput.jpg\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003ePeak throughput for training jobs\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eA training workload typically involves GPU nodes acting as clients that read data (text, images, etc.) from storage nodes to train deep neural networks like language or diffusion models. The throughput hovers around 6.6 TB/s\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eIt’s not made explicit if this read throughput is the average or median. I would assume the average throughput.\u003c/span\u003e on average, with peak throughput reaching 8 TB/s as reported in the \u003ca href=\"https://arxiv.org/abs/2408.14158\"\u003eFire-Flyer AI-HPC paper\u003c/a\u003e.\u003c/p\u003e\n\n\u003cp\u003eHere’s the hardware configuration the benchmark uses:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Node Type,Component,Specification-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Node Type,Component,Specification\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Node Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Specification\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNumber of nodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n 
\n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e500\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1 × 200Gbps NIC\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStorage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNumber of nodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e180\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e16 × 14TB PCIe 4.0 NVMe SSDs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2 × 200Gbps NICs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr 
class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMemory\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e512 GB DDR4-3200MHz\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row6-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row6-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eCPU\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row6-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2 × 32-core AMD EPYC (Rome/Milan)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eLet’s apply the “performance reality check” to these numbers. Below are some 
back-of-the-envelope calculations\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003e\u003ca href=\"https://en.wikipedia.org/wiki/Back-of-the-envelope_calculation\"\u003eBack-of-the-envelope calculations\u003c/a\u003e involve performing rough additions and multiplications to get an approximate number within the range of the actual value\u003c/span\u003e of the theoretical limits of this benchmark setup. Click the “Show calculations” toggle at the top right of the table to see the detailed breakdown. The per-SSD base numbers (7 GB/s sequential read, 4 GB/s random read, 6 GB/s sequential write, 2 GB/s random write) come from the specifications of a reference SSD I selected to represent this workload\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThe authors don’t list the SSD used in the benchmark, so I’ll be using a \u003ca href=\"https://www.micron.com/content/dam/micron/global/public/documents/products/technical-marketing-brief/7450-nvme-ssd-tech-prod-spec.pdf\"\u003eMicron 7450 15.36TB U.3 enterprise SSD\u003c/a\u003e as reference\u003c/span\u003e. Note also that NIC throughput is quoted in Gbps (gigabits per second) rather than GB/s (gigabytes per second); at 8 bits per byte, 200 Gbps ÷ 8 = 25 GB/s.\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cdiv id=\"performance-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003cdiv class=\"toggle-container\"\u003e\n \u003cdiv class=\"calc-toggle\"\u003e\n \u003cspan id=\"performance-table-toggle-text\" class=\"toggle-text\"\u003eShow calculations\u003c/span\u003e\n \u003cspan id=\"performance-table-toggle\" onclick=\"toggleCalculations('performance-table')\"\u003e\n \u003cspan class=\"toggle-switch\" id=\"performance-table-switch\"\u003e\u003c/span\u003e\n \u003cspan class=\"toggle-label\"\u003eToggle calculations\u003c/span\u003e\n \u003c/span\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \n \u003ctable id=\"performance-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Node Type\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Metric\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Per Unit\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Per Node\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n 
Entire Cluster\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStorage (180)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Sequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e7 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e112 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e112 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e7 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e20.16 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e20.16 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e112 GB/s × 180\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n 
\u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Random Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e64 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e64 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e4 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row1-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e11.52 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e11.52 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e64 GB/s × 180\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n 
\u003ctd id=\"performance-table-row2-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Sequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e6 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e96 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e96 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e6 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row2-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e17.28 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e17.28 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e96 GB/s × 180\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium 
text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Random Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row3-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e32 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e32 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e2 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row3-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e5.76 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e5.76 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e32 GB/s × 180\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n 
\u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row4-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row4-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e50 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e50 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e25 GB/s × 2\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row4-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e9 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e9 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e50 GB/s × 180\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient (500)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row5-col1\" class=\"px-4 py-2 
whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row5-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row5-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e12.5 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e12.5 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e25 GB/s × 500\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eML Training\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eAvg Read Throughput\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row6-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row6-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e6.6 TB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eML Training\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003ePeak Read Throughput\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row7-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"performance-table-row7-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e8 TB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eFrom 
these numbers, one can observe that:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003e\u003cspan data-highlight-cells=\"performance-table-row5-col4, performance-table-row4-col4\"\u003eThe client’s network will not be a bottleneck (\u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"rgba(255,255,0,0.2)\"\u003e12.5 TB/s client network throughput\u003c/span\u003e \u0026gt; \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"rgba(255,255,0,0.2)\"\u003e9 TB/s storage network throughput\u003c/span\u003e)\u003c/span\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eHover over the text to see the numbers highlighted in the table!\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003e\u003cspan data-highlight-cells=\"performance-table-row6-col4,performance-table-row1-col4\"\u003eThe training job workload indicates a mix of sequential and random reads: its average throughput is greater than the maximum random read throughput of the disks, so it cannot consist of random reads alone (\u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e6.6 TB/s\u003c/span\u003e \u0026gt; \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e5 TB/s\u003c/span\u003e)\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003e\u003cspan data-highlight-cells=\"performance-table-row0-col4, performance-table-row4-col4\"\u003eThe storage nodes will be bottlenecked by either the network or the disks, depending on the type of workload. 
A network bottleneck occurs when the workload primarily consists of sequential reads\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eAn example of this type of workload is reading a large file (a movie, a song, etc.) in order to transfer the data to another device\u003c/span\u003e, because the network cannot match the disks’ aggregate sequential throughput (\u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e20 TB/s\u003c/span\u003e \u0026gt; \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e9 TB/s\u003c/span\u003e)\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003e\u003cspan data-highlight-cells=\"performance-table-row1-col4,performance-table-row2-col4,performance-table-row3-col4,performance-table-row4-col4\"\u003eWhen the workload primarily consists of random reads, sequential writes, or random writes, the disks become the bottleneck rather than the network (\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e5 TB/s\u003c/span\u003e, \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e7.5 TB/s\u003c/span\u003e, \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e2.5 TB/s\u003c/span\u003e \u0026lt; \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e9 TB/s\u003c/span\u003e)\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eThis workload is most likely bottlenecked by network bandwidth. 
The sequential read throughput can reach up to \u003cspan data-highlight-cells=\"performance-table-row0-col4\" data-hover-text-color=\"darkred\" data-hover-cell-bg=\"#FFFFE0\"\u003e20 TB/s\u003c/span\u003e while the network throughput is capped at \u003cspan data-highlight-cells=\"performance-table-row4-col4\" data-hover-text-color=\"maroon\" data-hover-cell-bg=\"#FFFFE0\"\u003e9 TB/s\u003c/span\u003e, and both the peak throughput of \u003cspan data-highlight-cells=\"performance-table-row7-col4\" data-hover-text-color=\"maroon\" data-hover-cell-bg=\"#FFFFE0\"\u003e8 TB/s\u003c/span\u003e and the average throughput of \u003cspan data-highlight-cells=\"performance-table-row6-col4\" data-hover-text-color=\"maroon\" data-hover-cell-bg=\"#FFFFE0\"\u003e6.6 TB/s\u003c/span\u003e are below the network limit and well below the maximum sequential throughput.\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eRaw numbers can be hard to compare at a glance. If we replot them relative to the maximum sequential read throughput of the SSDs and lay them out on a bar plot, we get a better sense of where each one falls:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/paper_throughput_relative_to_sequential_reads.svg\" width=\"100%\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThe visualization confirms the points about system utilization made above:\u003c/p\u003e\n\n\u003cul\u003e\n \u003cli\u003eThe average 6.6 TB/s represents:\n \u003cul\u003e\n \u003cli\u003e33% of theoretical sequential disk throughput (6.6 / 20 TB/s)\u003c/li\u003e\n \u003cli\u003e73% of available network bandwidth (6.6 / 9 TB/s)\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n \u003cli\u003eThe peak 8 TB/s achieves:\n \u003cul\u003e\n \u003cli\u003e40% of theoretical sequential disk throughput (8 / 20 TB/s)\u003c/li\u003e\n \u003cli\u003e89% of available network bandwidth (8 / 9 
TB/s)\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe data points to network bandwidth as the primary bottleneck. This suggests two potential optimization paths: either remove half of the SSDs from each storage node or upgrade to 800 Gbps NICs to unlock the full sequential read potential. However, both changes present practical challenges: hardware platforms often have limitations that prevent NIC upgrades, and removing SSDs may leave PCIe slots unused. Cost alone may also make changing the existing setup unreasonable.\u003c/p\u003e\n\n\u003cp\u003eAlso, why does peak throughput cap at 8 TB/s rather than getting closer to the theoretical network limit of 9 TB/s? Is this a fundamental software limitation, or does it reflect overhead associated with network operations\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eCould be TCP or RDMA overhead\u003c/span\u003e at this scale?\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eI’ll have better answers to such questions when I run benchmarks on 3FS\u003c/span\u003e\u003c/p\u003e\n\n\u003ch3 id=\"revisiting-the-training-job-with-some-background\"\u003eRevisiting the training job with some background\u003c/h3\u003e\n\n\u003cp\u003eNow, let’s revisit the throughput-over-time graph with these background numbers in mind.\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/peak_throughput.jpg\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003ePeak throughput for training jobs\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n 
\n\u003c/div\u003e\n\n\u003cp\u003eThe graph shows read throughput hovering around 6.6 TB/s, which represents approximately 73% of the theoretical 9 TB/s network capacity\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eI typically set 0 as the starting point for the y axis, which gives you an absolute base number that you can compare to\u003c/span\u003e. This leaves 27% of potential throughput unutilized, suggesting possible system bottlenecks such as:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eMetadata communication network overhead (think TCP headers)\u003c/li\u003e\n \u003cli\u003eNetwork completion delays before reading\u003c/li\u003e\n \u003cli\u003eWorkload imbalance creating hot nodes\u003c/li\u003e\n \u003cli\u003eFUSE bottlenecks in the client for file operations\u003c/li\u003e\n \u003cli\u003eKernel overhead in managing communication and disk I/O\u003c/li\u003e\n \u003cli\u003eStraggler storage nodes slowed by disk issues (temperature, wear)\u003c/li\u003e\n \u003cli\u003eNative filesystem (XFS, ext4) overheads\u003c/li\u003e\n \u003cli\u003e…\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch3 id=\"dips-in-the-training-job\"\u003eDips in the training job\u003c/h3\u003e\n\n\u003cp\u003eThe periodic dips in throughput are somewhat interesting:\u003c/p\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/paper_dips.svg\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n\u003c/div\u003e\n\n\u003cp\u003eThese dips could originate from either the filesystem or the workload itself. The filesystem might have internal mechanisms (periodic flushing, lock contention, etc.) that could cause these performance drops. 
However, because the dips occur at regular ~2.5 second intervals, checkpointing is a likely cause\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eThe dips hover around 6.3 TB/s, so relative to the 6.6 TB/s average, that’s a 4.5% drop in throughput (300 GB/s / 6600 GB/s). These dips last roughly 10% of the time between peaks, so overall throughput may decrease by only about 0.45% (4.5% × 10%), which is not particularly significant.\u003c/span\u003e, since parts of the training job may need to pause while checkpoint data is written.\u003c/p\u003e\n\n\u003ch2 id=\"next-up-gray-sort-benchmark\"\u003eNext up: Gray Sort Benchmark\u003c/h2\u003e\n\n\u003ch3 id=\"what-is-gray-sort\"\u003eWhat is Gray Sort?\u003c/h3\u003e\n\n\u003cp\u003e\u003ca href=\"https://sortbenchmark.org/\"\u003eGray Sort\u003c/a\u003e is a synthetic benchmark that tests how quickly a system can sort large\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eLarge as in terabytes, far too much to fit on a single machine\u003c/span\u003e amounts of data. 
The workload follows a specific pattern that stresses both sequential and random I/O operations:\u003c/p\u003e\n\n\u003col\u003e\n \u003cli\u003eRead unsorted data from storage into memory (sequential reads)\u003c/li\u003e\n \u003cli\u003eSort each data chunk in memory\u003c/li\u003e\n \u003cli\u003eWrite the sorted chunks back to storage (random-ish writes)\u003c/li\u003e\n \u003cli\u003eFetch other nodes’ sorted chunks and merge them in memory (random-ish reads)\u003c/li\u003e\n \u003cli\u003eWrite the merged results back to disk (random-ish writes)\u003c/li\u003e\n \u003cli\u003eRepeat until all data is fully sorted\u003c/li\u003e\n \u003cli\u003eWrite the full sorted result to disk (sequential writes)\u003c/li\u003e\n\u003c/ol\u003e\n\n\u003cp\u003eThis alternating pattern of reads and writes, combined with both sort and merge phases, makes it a standard test for distributed filesystem performance\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eSounds like a \u003ca href=\"https://research.google/pubs/mapreduce-simplified-data-processing-on-large-clusters/\"\u003eMapReduce\u003c/a\u003e workload: essentially aggregating the keys in a range into one partition and then performing some operation on that range (merging, in this case)\u003c/span\u003e.\u003c/p\u003e\n\n\u003ch3 id=\"initial-look-at-the-graphs\"\u003eInitial Look at the Graphs\u003c/h3\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/gray_sort_client.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eGray Sort Single Node Client Performance\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n 
\n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/gray_sort_server.png\" style=\"width: 105%; margin-left: calc((100% - 105%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eGray Sort Single Node Server Performance\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eNote that orange represents writes and blue represents reads.\u003c/p\u003e\n\n\u003cp\u003eLooking at the orange dotted lines separating the algorithm phases, there’s a clear pattern. The phase before the first dotted line is pure writes – the system writing the unsorted data to the storage. After that, we see mixed read/write workloads that gradually shift toward being more read-heavy as the sorting progresses\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eAs more and more sorted runs get merged together, there are fewer write operations needed since each merge pass consolidates multiple inputs into fewer outputs, while the read operations increase to fetch data from the remaining sorted runs. 
This pattern is observable when comparing workload differences between the 18:05:00-18:10:00 and 18:25:00-18:30:00 time ranges in the server throughput graph.\u003c/span\u003e\u003c/p\u003e\n\n\u003cp\u003eA few observations jump out:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eEyeballing the graphs, the average combined (read + write) throughput per phase hovers around 10-15 GB/s.\u003c/li\u003e\n \u003cli\u003eClients peak at around 10 GB/s for writes and 22 GB/s for reads.\u003c/li\u003e\n \u003cli\u003eStorage nodes peak at 22 GB/s for writes and 30 GB/s for reads; their throughput is approximately twice the clients’ average throughput, which makes sense given there are twice as many clients as storage nodes (see the hardware configuration below).\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch3 id=\"hardware-configuration\"\u003eHardware Configuration\u003c/h3\u003e\n\n\u003cp\u003eFor this benchmark, 3FS was deployed with the following hardware setup:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"fancy-table-Node Type,Component,Specification-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-4 mb-4 overflow-x-auto\"\u003e\n \n \u003ctable id=\"fancy-table-Node Type,Component,Specification\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Node Type\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Component\n \u003c/th\u003e\n \n \u003cth class=\"px-6 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Specification\n \u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNumber of nodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n 
\n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row0-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e50\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row1-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e1 × 200Gbps NIC\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eMemory\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row2-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan 
style=\"color: rgb(75, 85, 99) !important;\"\u003e2.2TB DDR4\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStorage\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNumber of nodes\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row3-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row4-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e16 × 14TB PCIe 4.0 NVMe SSDs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr 
class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col1\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row5-col2\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2 × 400Gbps NICs\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"fancy-table-Node Type,Component,Specification-row6-col0\" class=\"px-6 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003ch3 id=\"analysis\"\u003eAnalysis\u003c/h3\u003e\n\n\u003cp\u003eThe main difference from the previous benchmark is that there are twice as many clients as there are storage nodes (compared to 3x from previous benchmark) and the storage nodes have twice as much network bandwidth!\u003c/p\u003e\n\n\u003cp\u003eLet’s calculate the theoretical performance limits for this configuration:\u003c/p\u003e\n\n\u003clink rel=\"stylesheet\" href=\"/assets/css/fancy_table.css\" /\u003e\n\n\u003cscript\u003e\nfunction toggleCalculations(tableId) {\n // ... 
(toggleCalculations function remains the same)\n const table = document.getElementById(tableId);\n const cells = table.querySelectorAll('.has-calculation');\n const tableWrapper = document.getElementById(tableId + '-wrapper');\n const toggleSwitch = document.getElementById(tableId + '-switch');\n const toggleText = document.getElementById(tableId + '-toggle-text');\n\n if (tableWrapper.classList.contains('show-calculations')) {\n tableWrapper.classList.remove('show-calculations');\n toggleSwitch.classList.remove('active');\n toggleText.textContent = \"Show calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'inline';\n calculationText.style.display = 'none';\n });\n } else {\n tableWrapper.classList.add('show-calculations');\n toggleSwitch.classList.add('active');\n toggleText.textContent = \"Hide calculations\";\n cells.forEach(cell =\u003e {\n const normalText = cell.querySelector('.normal-text');\n const calculationText = cell.querySelector('.calculation-text');\n normalText.style.display = 'none';\n calculationText.style.display = 'inline';\n });\n }\n}\n\nfunction isConsideredYellow(colorString) {\n if (!colorString) return false;\n const lowerColor = colorString.toLowerCase();\n\n const yellowKeywords = ['yellow', 'gold', 'lemon', 'chiffon', 'goldenrod', 'papayawhip', 'moccasin', 'khaki', '#ff0', '#ffff00', '#ffffe0', '#fffacd', '#fafad2', '#fff8dc', '#eee8aa', '#f0e68c'];\n if (yellowKeywords.some(k =\u003e lowerColor.includes(k))) {\n return true;\n }\n\n const match = lowerColor.match(/rgba?\\((\\d+),\\s*(\\d+),\\s*(\\d+)/);\n if (match) {\n const r = parseInt(match[1]);\n const g = parseInt(match[2]);\n const b = parseInt(match[3]);\n if (r \u003e 200 \u0026\u0026 g \u003e 180 \u0026\u0026 b \u003c 200 \u0026\u0026 Math.abs(r - g) \u003c 70) { \n return true;\n }\n }\n if (lowerColor === '#ffdab9') return 
true; // PeachPuff\n if (lowerColor === 'rgba(255,255,0,0.2)') return true; // Example: semi-transparent yellow\n return false;\n}\n\nfunction initializeCellHighlighters() {\n const highlighters = document.querySelectorAll('[data-highlight-cells]');\n\n highlighters.forEach(highlighter =\u003e {\n if (highlighter.dataset.highlighterInitialized === 'true') return;\n highlighter.dataset.highlighterInitialized = 'true';\n\n const cellIdsToHighlight = highlighter.dataset.highlightCells.split(',').map(id =\u003e id.trim());\n const cellElements = cellIdsToHighlight.map(id =\u003e document.getElementById(id)).filter(el =\u003e el);\n\n let descriptiveSpans = [];\n if (highlighter.matches('span[data-hover-text-color]')) {\n descriptiveSpans.push(highlighter);\n } else {\n descriptiveSpans = Array.from(highlighter.querySelectorAll('span[data-hover-text-color]'));\n }\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor === 'undefined') {\n span.dataset.originalParaTextColor = window.getComputedStyle(span).color || '';\n }\n });\n\n highlighter.addEventListener('mouseenter', () =\u003e {\n highlighter.classList.add('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (span.dataset.hoverTextColor) {\n span.style.color = span.dataset.hoverTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.add('cell-highlighted'); \n\n let hoverTextColorForCell = null;\n let hoverCellBgFromDescSpan = null;\n\n if (idx \u003c descriptiveSpans.length) {\n const descSpan = descriptiveSpans[idx];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n } else if (descriptiveSpans.length === 1 \u0026\u0026 cellElements.length \u003e 1) {\n const descSpan = descriptiveSpans[0];\n if (descSpan.dataset.hoverTextColor) {\n hoverTextColorForCell = 
descSpan.dataset.hoverTextColor;\n }\n if (descSpan.dataset.hoverCellBg) {\n hoverCellBgFromDescSpan = descSpan.dataset.hoverCellBg;\n }\n }\n\n\n if (hoverTextColorForCell) {\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n const calcWrapper = cell.querySelector('.calculation-text');\n if (calcWrapper \u0026\u0026 window.getComputedStyle(calcWrapper).display !== 'none') {\n if (calcResultSpan) calcResultSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n } else {\n if (normalTextSpan) normalTextSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n } else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.setProperty('color', hoverTextColorForCell, 'important');\n }\n }\n\n if (hoverCellBgFromDescSpan \u0026\u0026 isConsideredYellow(hoverCellBgFromDescSpan)) {\n if (typeof cell.dataset.originalBgColor === 'undefined') {\n cell.dataset.originalBgColor = cell.style.backgroundColor; \n }\n cell.style.backgroundColor = hoverCellBgFromDescSpan;\n cell.dataset.bgChangedByScript = 'true';\n }\n });\n });\n\n highlighter.addEventListener('mouseleave', () =\u003e {\n highlighter.classList.remove('trigger-text-active');\n\n descriptiveSpans.forEach(span =\u003e {\n if (typeof span.dataset.originalParaTextColor !== 'undefined') {\n span.style.color = span.dataset.originalParaTextColor;\n }\n });\n\n cellElements.forEach((cell, idx) =\u003e {\n cell.classList.remove('cell-highlighted');\n\n const isCalcCell = cell.classList.contains('has-calculation');\n if (isCalcCell) {\n const normalTextSpan = cell.querySelector('.normal-text');\n const calcResultSpan = cell.querySelector('.calc-result');\n if (normalTextSpan) normalTextSpan.style.removeProperty('color');\n if (calcResultSpan) calcResultSpan.style.removeProperty('color');\n } 
else {\n const directSpan = cell.querySelector(':scope \u003e span');\n if (directSpan) directSpan.style.removeProperty('color');\n }\n\n if (cell.dataset.bgChangedByScript === 'true') {\n cell.style.backgroundColor = cell.dataset.originalBgColor || '';\n delete cell.dataset.originalBgColor;\n delete cell.dataset.bgChangedByScript;\n }\n });\n });\n });\n}\n\nif (document.readyState === 'loading') {\n document.addEventListener('DOMContentLoaded', initializeCellHighlighters);\n} else {\n initializeCellHighlighters();\n}\n\u003c/script\u003e\n\n\u003cstyle\u003e\n/* ... (Other styles remain the same) ... */\n.toggle-container {\n display: flex;\n justify-content: flex-end;\n margin-bottom: 6px;\n}\n.calc-toggle {\n display: inline-flex;\n align-items: center;\n}\n.toggle-switch {\n position: relative;\n display: inline-block;\n width: 32px;\n height: 16px;\n background-color: #e5e7eb; /* gray-200 */\n border-radius: 10px;\n transition: all 0.3s;\n cursor: pointer;\n}\n.toggle-switch::after {\n content: '';\n position: absolute;\n width: 12px;\n height: 12px;\n border-radius: 50%;\n background-color: white;\n top: 2px;\n left: 2px;\n transition: all 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275);\n box-shadow: 0 1px 2px rgba(0, 0, 0, 0.1);\n}\n.toggle-switch.active {\n background-color: #8b5cf6; /* purple-500 */\n}\n.toggle-switch.active::after {\n left: 18px;\n}\n.toggle-label {\n position: absolute;\n width: 1px;\n height: 1px;\n padding: 0;\n margin: -1px;\n overflow: hidden;\n clip: rect(0, 0, 0, 0);\n white-space: nowrap;\n border-width: 0;\n}\n.toggle-text {\n font-size: 0.75rem;\n color: #6b7280; /* gray-500 */\n margin-right: 6px;\n}\n.calc-result {\n color: rgb(75, 85, 99); /* gray-700 */\n}\n.calc-formula {\n color: #6b83a6; /* Subtle bluish gray */\n font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;\n}\n.calc-equals {\n color: rgb(75, 85, 99); /* gray-700 */\n display: inline-block;\n margin: 0 4px;\n font-weight: 
normal;\n}\n\n.cell-highlighted {\n font-weight: bold !important; \n}\n\n.trigger-text-active {\n background-color: #E5E7EB; \n padding: 1px 3px; \n border-radius: 3px; \n transition: background-color 0.15s ease-in-out;\n}\n.trigger-text-active span[data-hover-text-color] {\n padding: 0; \n background-color: transparent; \n}\n\n/* Add styles for no-scroll option */\n.table-wrapper-no-scroll {\n overflow-x: visible !important;\n}\n\n.table-no-scroll td {\n white-space: normal !important;\n word-break: break-word;\n}\n\u003c/style\u003e\n\n\u003cdiv id=\"graysort-table-wrapper\" class=\"px-4 rounded-lg __basic-table not-prose mt-2 mb-2 overflow-x-auto\"\u003e\n \n \u003cdiv class=\"toggle-container\"\u003e\n \u003cdiv class=\"calc-toggle\"\u003e\n \u003cspan id=\"graysort-table-toggle-text\" class=\"toggle-text\"\u003eShow calculations\u003c/span\u003e\n \u003cspan id=\"graysort-table-toggle\" onclick=\"toggleCalculations('graysort-table')\"\u003e\n \u003cspan class=\"toggle-switch\" id=\"graysort-table-switch\"\u003e\u003c/span\u003e\n \u003cspan class=\"toggle-label\"\u003eToggle calculations\u003c/span\u003e\n \u003c/span\u003e\n \u003c/div\u003e\n \u003c/div\u003e\n \n \u003ctable id=\"graysort-table\" class=\"min-w-full divide-y divide-gray-200 font-sans basic-table-striped\"\u003e\n \u003cthead class=\"bg-gray-50\"\u003e\n \u003ctr\u003e\n \n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Node Type\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Metric\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Per Unit\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Per Node\n \u003c/th\u003e\n \n \u003cth class=\"px-4 py-2 text-left text-xs font-medium text-gray-500 tracking-wider\"\u003e\n Entire Cluster\n 
\u003c/th\u003e\n \n \u003c/tr\u003e\n \u003c/thead\u003e\n \u003ctbody\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row0-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eStorage (25)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row0-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Sequential Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row0-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e7 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row0-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e112 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e112 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e7 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row0-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e2.8 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e2.8 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e112 GB/s × 25\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr 
class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row1-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row1-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Random Read\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row1-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e4 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row1-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e64 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e64 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e4 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row1-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e1.6 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e1.6 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e64 GB/s × 25\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row2-col0\" class=\"px-4 py-2 
whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row2-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Sequential Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row2-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e6 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row2-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e96 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e96 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e6 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row2-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e2.4 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e2.4 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e96 GB/s × 25\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row3-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) 
!important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row3-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eDisk - Random Write\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row3-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e2 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row3-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e32 GB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e32 GB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e2 GB/s × 16\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row3-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e0.8 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e0.8 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e32 GB/s × 25\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row4-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row4-col1\" 
class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row4-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e100 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row4-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e100 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row4-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e2.5 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e2.5 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e100 GB/s × 25\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row5-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient (50)\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row5-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eNetwork\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row5-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 
GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row5-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e25 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row5-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600 has-calculation\"\u003e\n \u003cspan class=\"normal-text\"\u003e1.25 TB/s\u003c/span\u003e\n \u003cspan class=\"calculation-text\" style=\"display: none;\"\u003e\n \u003cspan class=\"calc-result\"\u003e1.25 TB/s\u003c/span\u003e\n \u003cspan class=\"calc-equals\"\u003e=\u003c/span\u003e\n \u003cspan class=\"calc-formula\"\u003e25 GB/s × 50\u003c/span\u003e\n \u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row6-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGray Sort\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row6-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient Write Peak\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row6-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row6-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~10 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row6-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm 
font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row7-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGray Sort\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row7-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eClient Read Peak\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row7-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row7-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~22 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row7-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row8-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGray Sort\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row8-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eServer 
Write Peak\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row8-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row8-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~22 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row8-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row9-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eGray Sort\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row9-col1\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eServer Read Peak\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row9-col2\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row9-col3\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e~30 GB/s\u003c/span\u003e\n \u003c/td\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row9-col4\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n 
\u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003eN/A\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \n \n \u003ctr class=\"border-b border-gray-200\"\u003e\n \n \n \n \n \n \n \u003ctd id=\"graysort-table-row10-col0\" class=\"px-4 py-2 whitespace-nowrap text-sm font-medium text-gray-600\"\u003e\n \u003cspan style=\"color: rgb(75, 85, 99) !important;\"\u003e\u003c/span\u003e\n \u003c/td\u003e\n \n \n \u003c/tr\u003e\n \n \n \u003c/tbody\u003e\n \u003c/table\u003e\n\u003c/div\u003e\n\n\u003cp\u003eThe performance numbers reveal an interesting pattern. In the first phase, the server write peak achieves \u003cspan data-highlight-cells=\"graysort-table-row8-col3, graysort-table-row3-col3\"\u003e\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e~22 GB/s\u003c/span\u003e out of \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e32 GB/s\u003c/span\u003e random write capacity\u003c/span\u003e (69% utilization). In the second phase, the server read peak reaches \u003cspan data-highlight-cells=\"graysort-table-row9-col3, graysort-table-row1-col3\"\u003e\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e~30 GB/s\u003c/span\u003e out of \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e64 GB/s\u003c/span\u003e random read capacity\u003c/span\u003e (47% utilization), which is quite a bit lower than the relative utilization for writes. 
However, \u003cspan data-highlight-cells=\"graysort-table-row9-col3, graysort-table-row0-col3\"\u003ecomparing the \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e~30 GB/s\u003c/span\u003e read peak against the \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e112 GB/s\u003c/span\u003e sequential read capacity\u003c/span\u003e (27% utilization) strongly suggests that the workload is predominantly random rather than sequential.\u003c/p\u003e\n\n\u003cp\u003eLet’s take a look at the numbers:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eStorage nodes peak at \u003cspan data-highlight-cells=\"graysort-table-row8-col3, graysort-table-row9-col3, graysort-table-row4-col3\"\u003e\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e22 GB/s writes and 30 GB/s reads\u003c/span\u003e, well below the \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e100 GB/s network capacity\u003c/span\u003e\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eClient read peak achieves \u003cspan data-highlight-cells=\"graysort-table-row7-col3, graysort-table-row5-col2\"\u003e88% of network capacity (\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e22 GB/s\u003c/span\u003e out of \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e25 GB/s\u003c/span\u003e)\u003c/span\u003e\u003c/li\u003e\n \u003cli\u003eClient write peak hits only \u003cspan data-highlight-cells=\"graysort-table-row6-col3, graysort-table-row5-col2\"\u003e40% of network capacity (\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e10 GB/s\u003c/span\u003e out of \u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e25 GB/s\u003c/span\u003e)\u003c/span\u003e\u003cspan
class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eWhy does the writes not peak nearly as high as reads? A reason might be from CRAQ’s consistency guarantees - each write must traverse the entire chain (head → middle → tail → back), which makes performance predictable unlike reads. Reads can come from the follower, or cause a consistency check at the tail\u003c/span\u003e\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe bottleneck here is clearly the number of clients. With the storage nodes far from saturated, we could support more clients. How many? If we want to saturate the storage sequential write capacity of \u003cspan data-highlight-cells=\"graysort-table-row2-col4, graysort-table-row5-col2\"\u003e\u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e2.4 TB/s\u003c/span\u003e and each client can push \u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e25 GB/s\u003c/span\u003e\u003c/span\u003e:\u003c/p\u003e\n\n\u003cp\u003e2.4 TB/s ÷ 25 GB/s = ~96 clients\u003c/p\u003e\n\n\u003cp\u003eNearly double the current 50 clients! This suggests the current configuration may be significantly underutilizing the storage infrastructure.\u003c/p\u003e\n\n\u003cp\u003eInterestingly, \u003cspan data-highlight-cells=\"graysort-table-row8-col3, graysort-table-row6-col3\"\u003ethe storage write peak (\u003cspan data-hover-text-color=\"rgba(80, 150, 100, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e22 GB/s\u003c/span\u003e) slightly exceeds client write peak (\u003cspan data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e20 = 2 × 10 GB/s\u003c/span\u003e)\u003c/span\u003e. 
With 50 clients at 10 GB/s distributed across 25 storage nodes, each node should see ~20 GB/s; the extra 2 GB/s may come from CRAQ protocol overhead.\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eCRAQ requires writes to propagate through chains, potentially creating additional write traffic beyond what clients generate\u003c/span\u003e\u003c/p\u003e\n\n\u003cp\u003eThe end-to-end performance measurements, however, reveal an unexpected pattern: the \u003ca href=\"https://github.com/deepseek-ai/3FS/tree/ee9a5cee0a85c64f4797bf380257350ca1becd36\"\u003ebenchmark notes mention achieving 3.66 TB/min\u003c/a\u003e, or 61 GB/s of aggregate throughput, which doesn’t sound too bad. But that’s just 4.88% of the \u003cspan data-highlight-cells=\"graysort-table-row5-col4\" data-hover-text-color=\"rgba(52, 152, 219, 0.9)\" data-hover-cell-bg=\"#FFFFE0\"\u003e1.25 TB/s client network capacity\u003c/span\u003e. This suggests that the bottleneck might not be the network or disks at all; the job could even be CPU- or memory-bound in the sorting computation itself.\u003c/p\u003e\n\n\u003ch2 id=\"caching-the-key-value-pairs-of-the-transformer\"\u003eCaching the key-value pairs of the transformer\u003c/h2\u003e\n\n\u003ch3 id=\"what-is-the-kv-cache\"\u003eWhat is the KV Cache?\u003c/h3\u003e\n\n\u003cp\u003eThe KV cache stores the key-value pairs produced by the attention layers during LLM inference. Instead of recomputing these values for every new token, the system caches them, dramatically reducing computational overhead by trading storage for computation. For models like DeepSeek’s R1, this cache becomes substantial: each token requires approximately 70KB of storage in FP16 format.\u003c/p\u003e\n\n\u003cp\u003eThis workload represents an important real-world use case for 3FS. 
As LLMs process longer contexts and serve more users concurrently, the storage system must handle both massive reads (loading cached values) and periodic deletions (garbage collecting expired entries).\u003c/p\u003e\n\n\u003ch3 id=\"initial-look-at-the-graphs-1\"\u003eInitial Look at the Graphs\u003c/h3\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/kvcache_read_throughput.png\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eKV Cache Read Throughput\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/kvcache_gc_iops.png\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eKV Cache GC IOPS\n \n (Image source: \u003ca href=\"https://github.com/deepseek-ai/3FS\" rel=\"external nofollow noopener\" target=\"_blank\"\u003e3FS github\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n\u003cp\u003eThe graphs show per-client performance for KV cache operations. 
Looking at the read throughput graph:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAverage throughput hovers around 3 GB/s\u003c/li\u003e\n \u003cli\u003ePeak throughput reaches approximately 40 GB/s\u003c/li\u003e\n \u003cli\u003eThat is a gap of more than 13x between average and peak\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThe GC IOPS graph reveals:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003ePeriodic bursts of deletion operations reaching 1-1.4M IOPS\u003c/li\u003e\n \u003cli\u003e~4 bursts per 5-minute interval\n \u003cul\u003e\n \u003cli\u003eEach burst lasts ~40 seconds and is followed by a similarly long period of low activity\u003c/li\u003e\n \u003c/ul\u003e\n \u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eUnfortunately, the authors don’t specify the complete hardware configuration; we only know each client has a 400 Gbps NIC (50 GB/s). This means the 40 GB/s peak reaches 80% network utilization, while the 3 GB/s average uses only 6% of available bandwidth.\u003c/p\u003e\n\n\u003ch3 id=\"analyzing-the-workload\"\u003eAnalyzing the Workload\u003c/h3\u003e\n\n\u003cp\u003eThe read pattern is fundamentally random: individual KV entries are scattered across storage. However, each 70KB entry spans multiple 4KB blocks\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eSSDs read data in fixed-size blocks, typically 4KB. 
A 70KB entry requires reading 18 consecutive blocks\u003c/span\u003e, so reads within an entry are sequential at the device level even though accesses across entries are random.\u003c/p\u003e\n\n\u003cp\u003eLet me calculate what these throughput numbers mean for token processing:\u003c/p\u003e\n\n\u003cdetails style=\"font-size: 1.2em; margin: 1em 0;\"\u003e\n \u003csummary style=\"cursor: pointer; font-weight: bold;\"\u003eExpand for KV cache entry size calculations\u003c/summary\u003e\n\n \u003cp\u003e\u003ca href=\"https://github.com/deepseek-ai/DeepSeek-V3/blob/4c2fdb8f55e049553b9f4f1a3241f86d739c8cf8/inference/configs/config_671B.json\"\u003e671B configuration\u003c/a\u003e\u003c/p\u003e\n \u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e{\n \"vocab_size\": 129280,\n \"dim\": 7168,\n \"inter_dim\": 18432,\n \"moe_inter_dim\": 2048,\n \"n_layers\": 61,\n \"n_dense_layers\": 3,\n \"n_heads\": 128,\n \"n_routed_experts\": 256,\n \"n_shared_experts\": 1,\n \"n_activated_experts\": 8,\n \"n_expert_groups\": 8,\n \"n_limited_groups\": 4,\n \"route_scale\": 2.5,\n \"score_func\": \"sigmoid\",\n \"q_lora_rank\": 1536,\n \"kv_lora_rank\": 512,\n \"qk_nope_head_dim\": 128,\n \"qk_rope_head_dim\": 64,\n \"v_head_dim\": 128,\n \"dtype\": \"fp8\"\n}\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e \u003c/div\u003e\n\n \u003cdiv class=\"image-container\" style=\"max-width: 100%; overflow: visible;\"\u003e\n \u003cimg loading=\"lazy\" src=\"/assets/images/posts/2025-03-13/part2/paper_mla.png\" style=\"width: 125%; margin-left: calc((100% - 125%) / 2);\" alt=\"\" /\u003e\n \n \n \u003cdiv class=\"caption\"\u003e\n \u003cem\u003eKV Cache MLA calculation described in DeepSeek-V2\n \n (Image source: \u003ca href=\"https://arxiv.org/pdf/2405.04434\" rel=\"external nofollow noopener\" target=\"_blank\"\u003eDeepSeek-V2: A Strong, Economical, and Efficient 
Mixture-of-Experts Language 
Model\u003c/a\u003e)\n \n \u003c/em\u003e\n \u003c/div\u003e\n \n\u003c/div\u003e\n\n \u003cp\u003eGiven:\u003c/p\u003e\n \u003cul\u003e\n \u003cli\u003ekv_lora_rank = 512\u003c/li\u003e\n \u003cli\u003eqk_rope_head_dim = 64\u003c/li\u003e\n \u003cli\u003en_layers = 61\u003c/li\u003e\n \u003c/ul\u003e\n\n \u003cp\u003ePer token: (512 + 64) × 61 = 35,136 elements\u003c/p\u003e\n\n \u003cp\u003eIn FP16 format (2 bytes per element) = 70,272 bytes ≈ 70KB per token\u003cbr /\u003e\nIn FP8 format (1 byte per element) = 35,136 bytes ≈ 35KB per token\u003c/p\u003e\n\n\u003c/details\u003e\n\n\u003cp\u003eWith 70KB per token:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAverage throughput (3 GB/s) loads the KV cache for ~43,000 tokens/second per client\u003c/li\u003e\n \u003cli\u003ePeak throughput (40 GB/s) loads the KV cache for ~570,000 tokens/second per client\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eGiven R1’s 128K context length:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eAverage: can read an entire context in ~3 seconds (128K ÷ 43K)\u003c/li\u003e\n \u003cli\u003ePeak: can read an entire context in ~0.22 seconds (128K ÷ 570K)\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eThese numbers are impressive, but without knowing the number of concurrent users or typical context lengths, it’s hard to judge real-world performance.\u003c/p\u003e\n\n\u003ch3 id=\"alignment-concerns\"\u003eAlignment Concerns\u003c/h3\u003e\n\n\u003cp\u003eHere’s an issue the authors don’t address: alignment waste. 
Modern NVMe SSDs use 4KB block sizes, but a 70,272-byte KV cache entry is not a multiple of 4KB.\u003c/p\u003e\n\n\u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003eBlocks needed = ⌈70,272 ÷ 4,096⌉ = 18 blocks\nActual storage = 18 × 4,096 = 73,728 bytes\nWasted space = 73,728 - 70,272 = 3,456 bytes (4.69% of allocated storage)\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e\n\n\u003cp\u003eThis 4.69% waste might seem small, but at scale it adds up. 
With enterprise SSDs costing ~$2,200 each:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eWasted cost per SSD: ~$103\u003c/li\u003e\n \u003cli\u003eWasted cost per node (16 SSDs): ~$1,650\u003c/li\u003e\n \u003cli\u003eWasted cost across 180 nodes: ~$297,000\u003c/li\u003e\n \u003cli\u003eWasted cost across 10,000 nodes: ~$16,500,000\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003cp\u003eFor a company running thousands of clusters, this alignment inefficiency could waste millions in storage costs.\u003c/p\u003e\n\n\u003ch3 id=\"garbage-collection\"\u003eGarbage Collection\u003c/h3\u003e\n\n\u003cp\u003eThe GC algorithm isn’t detailed, but entries likely get marked for deletion when no longer referenced. The deletion mechanism remains unclear; it could involve bit flags, pointer updates, zeroing entries, or \u003ca href=\"https://en.wikipedia.org/wiki/Log-structured_merge-tree#Operations\"\u003etombstone markers\u003c/a\u003e.\u003c/p\u003e\n\n\u003cp\u003eThe periodic burst pattern (1-1.4M IOPS) suggests threshold-based eviction or batch processing rather than continuous cleanup, which is probably more efficient for this type of workload. While throughput remains stable during GC, these spikes could hurt performance if disks are already near throughput capacity\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eGarbage collection problems have appeared in many existing systems, showing up as \u003ca href=\"https://github.com/facebook/rocksdb/issues/3972\"\u003ecompaction issues in RocksDB\u003c/a\u003e or \u003ca href=\"https://stackoverflow.com/questions/54831212/postgresql-autovacuum-causing-significant-performance-degradation\"\u003eautovacuum spikes in Postgres\u003c/a\u003e\u003c/span\u003e.\u003c/p\u003e\n\n\u003ch3 id=\"remaining-feedback\"\u003eRemaining feedback\u003c/h3\u003e\n\n\u003cp\u003eSome critical information is absent from this benchmark, most notably latency graphs. 
For LLM serving, latency matters as much as throughput: users need consistent time-to-first-token and smooth text generation, or they’ll switch to another service (ChatGPT, Gemini, Claude, etc.).\u003c/p\u003e\n\n\u003cp\u003eIf this is a real sample from a live system, someone at DeepSeek clearly knows how to configure systems well. The 80% peak utilization indicates a well-provisioned system with just enough headroom.\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003eNobody wants that 3am call about needing to set up more machines to handle the traffic\u003c/span\u003e\u003c/p\u003e\n\n\u003ch1 id=\"closing-thoughts\"\u003eClosing Thoughts\u003c/h1\u003e\n\n\u003cp\u003eThe benchmarks focus exclusively on throughput, omitting latency metrics entirely. It’s unclear why latency was skipped; perhaps cost considerations took priority. While latency optimization is notoriously difficult\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003e\u003ca href=\"http://www.stuartcheshire.org/rants/latency.html\"\u003eStuart Cheshire: “It’s the latency, stupid”\u003c/a\u003e\u003c/span\u003e\u003cspan class=\"sidenote-ref\"\u003e\u003c/span\u003e\u003cspan class=\"sidenote\"\u003e\u003ca href=\"https://www.barroso.org/publications/TheTailAtScale.pdf\"\u003eJeff Dean on tail latencies at scale\u003c/a\u003e\u003c/span\u003e, my future evaluations will include latency measurements and explore optimizations to reduce it.\u003c/p\u003e\n\n\u003cp\u003eDespite these limitations and critiques, the benchmarks align well with theoretical calculations and provide valuable insights into 3FS performance at scale.\u003c/p\u003e\n\n\u003cp\u003eIn upcoming posts, I’ll benchmark 3FS myself to verify these graphs and claims and dig deeper:\u003c/p\u003e\n\u003cul\u003e\n \u003cli\u003eTesting actual hardware limits vs theoretical calculations\u003c/li\u003e\n \u003cli\u003eMeasuring latency distributions, not 
just throughput\u003c/li\u003e\n \u003cli\u003eCreating custom visualizations for storage and network performance patterns\u003c/li\u003e\n \u003cli\u003eValidating whether our back-of-the-envelope math holds up\u003c/li\u003e\n \u003cli\u003eProfiling with various tools (perf, sampling, adapting source code) to identify bottlenecks\u003c/li\u003e\n\u003c/ul\u003e\n\n\u003ch1 id=\"acknowledgments\"\u003eAcknowledgments\u003c/h1\u003e\n\n\u003cp\u003eThanks to \u003ca href=\"https://sbaziotis.com/\"\u003eStefanos Baziotis\u003c/a\u003e, \u003ca href=\"https://www.linkedin.com/in/ahan-gupta-405619103/\"\u003eAhan Gupta\u003c/a\u003e, and \u003ca href=\"https://vimarsh.me/\"\u003eVimarsh Sathia\u003c/a\u003e for reviewing this post.\u003c/p\u003e\n\n\u003ch1 id=\"citation\"\u003eCitation\u003c/h1\u003e\n\n\u003cp\u003eTo cite this article:\u003c/p\u003e\n\n\u003cdiv class=\"language-plaintext highlighter-rouge\"\u003e\u003cdiv class=\"highlight\"\u003e\u003cpre class=\"highlight\"\u003e\u003ccode\u003e@article{zhu20253fs2,\n title = {A Reality Check on DeepSeek's Distributed File System Benchmarks},\n author = {Zhu, Henry},\n journal = {maknee.github.io},\n year = {2025},\n month = {June},\n url = \"https://maknee.github.io/blog/2025/3FS-Performance-Journal-2/\"\n}\n\u003c/code\u003e\u003c/pre\u003e\u003c/div\u003e\u003c/div\u003e","summary":"Series","date_published":"2025-06-18T09:00:00+00:00","date_modified":"2025-06-18T09:00:00+00:00","author":{"name":""},"tags":["3FS"]},{"id":"https://maknee.github.io/blog/2025/Paul-Graham-What-To-Do","url":"https://maknee.github.io/blog/2025/Paul-Graham-What-To-Do/","title":"Paul Graham - What to Do","content_html":"\u003ch3 id=\"paul-graham---what-to-do\"\u003ePaul Graham - What to Do\u003c/h3\u003e\n\n\u003cp\u003eThoughts about \u003ca href=\"https://paulgraham.com/do.html\"\u003eWhat to Do\u003c/a\u003e by Paul Graham.\u003c/p\u003e\n\n\u003ch4 id=\"thoughts-along-the-way\"\u003eThoughts along the 
way\u003c/h4\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat should one do? That may seem a strange question, but it’s not meaningless or unanswerable. It’s the sort of question kids ask before they learn not to ask big questions.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis statement about kids caught me off guard; I do see it happen (at least in myself). Why, though? Does he see it in his own children and the kids he encounters? What does he consider “kids” in this context: elementary school, high school, college? I see this explained by the hierarchy of societies, most definitely in the military.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eI only came across it myself in the process of investigating something else. But once I did, I thought I should at least try to answer it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh, I haven’t explained why I was caught off guard: it’s because I haven’t thought about this in a long time. And I don’t have an answer yet.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eSo what should one do? One should help people, and take care of the world. Those two are obvious.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis is how kids would answer.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBut is there anything else? When I ask that, the answer that pops up is Make good new things.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eWhat good things? How do you know that they are good? Or new?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe most impressive thing humans can do is to think. It may be the most impressive thing that can be done. And the best kind of thinking, or more precisely the best proof that one has thought well, is to make good new things.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI believe in this, and he has stated it well in very concise sentences. 
I like it.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eNewton’s physics was a good new thing.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eSurprised by an example. The concept could feel very abstract without a concrete example (here, a discovery everyone knows about).\u003c/p\u003e\n\n\u003cp\u003eI’m going to guess that this discovery allowed people to develop technology (ships, safety, etc.)?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIndeed, the first version of this principle was to have good new ideas. But that didn’t seem general enough: it didn’t include making art or music, for example, except insofar as they embody new ideas. And while they may embody new ideas, that’s not all they embody, unless you stretch the word “idea” so uselessly thin that it includes everything that goes through your nervous system.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eI don’t understand this very well; I think he’s trying to explain how general a new idea can be, but I don’t think it has to be general. It’s very very very very very very very difficult to make a general good new idea. I believe such an idea is built upon the ideas of many people (hundreds, thousands, millions, etc.). I see this repeated.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eTo make discoveries, for example, or to understand something more deeply than others have. But how well do you understand something if you can’t make a model of it, or write about it? Indeed, trying to express what you understand is not just a way to prove that you understand it, but a way to understand it better.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe more I do this, the more I believe in it.\u003c/p\u003e\n\n\u003cp\u003eI think I’ve applied it to a teeny bit of my life. 
And I hope the same rule will apply to other aspects of life and experience.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eAnother reason I like this phrasing is that it biases us toward creation. It causes us to prefer the kind of ideas that are naturally seen as making things rather than, say, making critical observations about things other people have made. Those are ideas too, and sometimes valuable ones, but it’s easy to trick oneself into believing they’re more valuable than they are.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTwo parts to this.\u003c/p\u003e\n\n\u003cp\u003eI don’t agree that this phrasing biases us toward creation. It seems forced; I didn’t read it that way originally. Discoveries (albeit repeated among different individuals) can fall under this term. I believe it’s more about thinking and learning.\u003c/p\u003e\n\n\u003cp\u003eYes, I agree with what Paul states about observations. Even an intellectual you look up to may make a wrong guess. For example, the \u003ca href=\"https://en.wikipedia.org/wiki/Tanenbaum%E2%80%93Torvalds_debate\"\u003egodfather of operating systems\u003c/a\u003e lost a debate to Linus, who argued that Linux would succeed as a monolithic kernel. Imagine that: a random ass college kid (Linus was 22 at the time) tells the most well-known and accomplished operating systems professor of the day that his hobby operating system would win. If I had been a random person in this flame war, I definitely would not have sided with Linus’ arguments.\u003c/p\u003e\n\n\u003cp\u003eAnd I see this often in my life as well. People make observations all the time, and sometimes, once something actually happens, they turn out to be wrong. Should you believe their observations? 
Sometimes.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eCriticism seems sophisticated, and making new things often seems awkward, especially at first; and yet it’s precisely those first steps that are most rare and valuable.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThis next statement felt a bit disconnected from the previous sentences, though I do think it’s necessary. Observations may seem the most rare and valuable, but I believe it’s generally a series of observations, and it takes time to form thoughts about different or unusual ones.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIs newness essential? I think so. Obviously it’s essential in science. If you copied a paper of someone else’s and published it as your own, it would seem not merely unimpressive but dishonest.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eInteresting statement about papers.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhich in turn implies it’s not impressive to make the same thing over and over, however well; you’re just copying yourself.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eThe problem is that there isn’t much learning (working through the problems and pain of each step to reach the end), which I think is what Paul is getting at.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eHistorically most rules about how to live have been a mix of both kinds of should, though usually with more of the former than the latter.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eNice observation.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eArchimedes knew that he was the first to prove that a sphere has 2/3 the volume of the smallest enclosing cylinder and was very pleased about it. But you don’t find ancient writers urging their readers to emulate him. 
They regarded him more as a prodigy than a model.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eVery interesting observation. Why not emulate him?\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eNow many more of us can follow Archimedes’s example and devote most of our attention to one kind of work.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eOh.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eWhat kinds of new things count? I’d rather leave that question to the makers of them.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eHe didn’t answer the question :( but it’s the right answer to give.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eIt would be a risky business to try to define any kind of threshold, because new kinds of work are often despised at first. Raymond Chandler was writing literal pulp fiction, and he’s now recognized as one of the best writers of the twentieth century. Indeed this pattern is so common that you can use it as a recipe: if you’re excited about some kind of work that’s not considered prestigious and you can explain what everyone else is overlooking about it, then this is not merely a kind of work that’s ok to do, but one to seek out.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eA good statement next. I’m focusing on the “hey, it’s not good at first” part; I’ve seen this a couple of times already. But Paul doesn’t mention the other factors (time, mental stress, comfort, physical toll, and so on), which can be minor or major hurdles on such a route. 
It is sometimes brutal to go down such a path.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eThe kind of people who make good new things don’t need rules to keep them honest.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eTrue, but again there are hurdles, and this time they include other people.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eBut even if you’re one of those, you should at least make sure that the new things you make don’t net harm people or the world.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eVery hard to see sometimes.\u003c/p\u003e\n\n\u003cblockquote\u003e\n \u003cp\u003eOn the other hand, if you make something amazing, you’ll often be helping people or the world even if you didn’t mean to. Newton was driven by curiosity and ambition, not by any practical effect his work might have, and yet the practical effect of his work has been enormous. And this seems the rule rather than the exception. So if you think you can make something amazing, you should probably just go ahead and do it.\u003c/p\u003e\n\u003c/blockquote\u003e\n\n\u003cp\u003eGreat ending. “Just do it” is easier said than done, that’s for sure, and that’s another thing Paul doesn’t mention. It’s like the gym: it takes reps to build muscle, and it takes reps to make something amazing. “Just do it”, yes, but make sure you see your goals clearly in the moment and learn to identify plateaus. Nobody in the Olympics just randomly went at it, did one thing, and became good at their sport.\u003c/p\u003e\n\n\u003ch4 id=\"final-thoughts\"\u003eFinal thoughts\u003c/h4\u003e\n\n\u003cp\u003eTo answer this generically: I can’t. But for doing stuff you’re interested in and trying to create something, some ways to reach the point of creating something are talking to people, reading or watching the literature, doing something and then thinking, or guessing, doing something, and then thinking. 
However, going through it may not be fun at times, may actually be uninteresting at times, or may even be depressing and cause someone to re-evaluate a lot (to feel lost). I think that as long as one holds the belief at one’s core, one will make progress. Ask anyone you consider successful what they failed at or when they felt lost; they will likely name an event or two that stick out, or even mention that they wanted to give up.\u003c/p\u003e","summary":"Paul Graham - What to Do","date_published":"2025-04-20T00:00:00+00:00","date_modified":"2025-04-20T00:00:00+00:00","author":{"name":""},"tags":["Rambling"]}]}