ggml webgpu: shader library organization (#19530)
* Basic JIT compilation for mul_mat, get_rows, and scale (#17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local> Co-authored-by: Reese Levine <reeselevine1@gmail.com> * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * flashattention and matrix multiplication moved to new format * clean up preprocessing * Formatting * remove duplicate constants * Split large shaders into multiple static strings --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com>
This commit is contained in:
parent
ea003229d3
commit
238856ec8f
11 changed files with 1682 additions and 2018 deletions
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -1,5 +1,4 @@
|
|||
#decl(BYTE_HELPERS)
|
||||
|
||||
#ifdef BYTE_HELPERS
|
||||
fn get_byte(value: u32, index: u32) -> u32 {
|
||||
return (value >> (index * 8)) & 0xFF;
|
||||
}
|
||||
|
|
@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 {
|
|||
fn get_byte_i32(value: u32, index: u32) -> i32 {
|
||||
return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(BYTE_HELPERS)
|
||||
|
||||
#decl(Q4_0_T)
|
||||
#ifdef Q4_0_T
|
||||
struct q4_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#enddecl(Q4_0_T)
|
||||
#endif
|
||||
|
||||
#decl(Q4_1_T)
|
||||
#ifdef Q4_1_T
|
||||
struct q4_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: array<u32, 4>
|
||||
};
|
||||
#enddecl(Q4_1_T)
|
||||
#endif
|
||||
|
||||
#decl(Q5_0_T)
|
||||
#ifdef Q5_0_T
|
||||
struct q5_0 {
|
||||
d: f16,
|
||||
qh: array<f16, 2>,
|
||||
qs: array<f16, 8>
|
||||
};
|
||||
#enddecl(Q5_0_T)
|
||||
#endif
|
||||
|
||||
#decl(Q5_1_T)
|
||||
#ifdef Q5_1_T
|
||||
struct q5_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qh: u32,
|
||||
qs: array<u32, 4>
|
||||
};
|
||||
#enddecl(Q5_1_T)
|
||||
#endif
|
||||
|
||||
#decl(Q8_0_T)
|
||||
#ifdef Q8_0_T
|
||||
struct q8_0 {
|
||||
d: f16,
|
||||
qs: array<f16, 16>
|
||||
};
|
||||
#enddecl(Q8_0_T)
|
||||
#endif
|
||||
|
||||
#decl(Q8_1_T)
|
||||
#ifdef Q8_1_T
|
||||
struct q8_1 {
|
||||
d: f16,
|
||||
m: f16,
|
||||
qs: array<u32, 8>
|
||||
};
|
||||
#enddecl(Q8_1_T)
|
||||
#endif
|
||||
|
||||
#decl(Q2_K_T)
|
||||
struct q2_k {
|
||||
#ifdef Q2_K_T
|
||||
struct q2_K {
|
||||
scales: array<u32, 4>,
|
||||
qs: array<u32, 16>,
|
||||
d: f16,
|
||||
dmin: f16
|
||||
};
|
||||
#enddecl(Q2_K_T)
|
||||
#endif
|
||||
|
||||
#decl(Q3_K_T)
|
||||
struct q3_k {
|
||||
#ifdef Q3_K_T
|
||||
struct q3_K {
|
||||
hmask: array<f16, 16>,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 6>,
|
||||
d: f16
|
||||
};
|
||||
#enddecl(Q3_K_T)
|
||||
|
||||
#decl(Q45_K_SCALE_MIN)
|
||||
#endif
|
||||
|
||||
#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN)
|
||||
fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
||||
if (is < 4) {
|
||||
let sc_byte = get_byte(scales[is / 4], is % 4);
|
||||
|
|
@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> {
|
|||
return vec2(f32(sc), f32(m));
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(Q45_K_SCALE_MIN)
|
||||
|
||||
#decl(Q4_K_T)
|
||||
struct q4_k {
|
||||
#endif
|
||||
#ifdef Q4_K_T
|
||||
struct q4_K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: array<u32, 3>,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(Q4_K_T)
|
||||
#endif
|
||||
|
||||
#decl(Q5_K_T)
|
||||
struct q5_k {
|
||||
#ifdef Q5_K_T
|
||||
struct q5_K {
|
||||
d: f16,
|
||||
dmin: f16,
|
||||
scales: array<u32, 3>,
|
||||
qh: array<u32, 8>,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(Q5_K_T)
|
||||
#endif
|
||||
|
||||
#decl(Q6_K_T)
|
||||
struct q6_k {
|
||||
#ifdef Q6_K_T
|
||||
struct q6_K {
|
||||
ql: array<f16, 64>,
|
||||
qh: array<f16, 32>,
|
||||
scales: array<f16, 8>,
|
||||
d: f16
|
||||
};
|
||||
#enddecl(Q6_K_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_XXS_T)
|
||||
#ifdef IQ2_XXS_T
|
||||
struct iq2_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>
|
||||
};
|
||||
#enddecl(IQ2_XXS_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_XS_T)
|
||||
#ifdef IQ2_XS_T
|
||||
struct iq2_xs {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#enddecl(IQ2_XS_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_S_T)
|
||||
#ifdef IQ2_S_T
|
||||
struct iq2_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
qh: array<f16, 4>,
|
||||
scales: array<f16, 4>
|
||||
};
|
||||
#enddecl(IQ2_S_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ3_XSS_T)
|
||||
#ifdef IQ3_XXS_T
|
||||
struct iq3_xxs {
|
||||
d: f16,
|
||||
qs: array<f16, 48>
|
||||
};
|
||||
#enddecl(IQ3_XSS_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ3_S_T)
|
||||
#ifdef IQ3_S_T
|
||||
struct iq3_s {
|
||||
d: f16,
|
||||
qs: array<f16, 32>,
|
||||
|
|
@ -161,41 +156,41 @@ struct iq3_s {
|
|||
signs: array<f16, 16>,
|
||||
scales: array<f16, 2>
|
||||
};
|
||||
#enddecl(IQ3_S_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ1_S_T)
|
||||
#ifdef IQ1_S_T
|
||||
struct iq1_s {
|
||||
d: f16,
|
||||
qs: array<f16, 16>,
|
||||
qh: array<f16, 8>
|
||||
};
|
||||
#enddecl(IQ1_S_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ1_M_T)
|
||||
#ifdef IQ1_M_T
|
||||
struct iq1_m {
|
||||
qs: array<u32, 8>,
|
||||
qh: array<u32, 4>,
|
||||
scales: array<u32, 2>
|
||||
};
|
||||
#enddecl(IQ1_M_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ4_NL_T)
|
||||
#ifdef IQ4_NL_T
|
||||
struct iq4_nl {
|
||||
d: f16,
|
||||
qs: array<f16, 8>,
|
||||
};
|
||||
#enddecl(IQ4_NL_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ4_XS_T)
|
||||
#ifdef IQ4_XS_T
|
||||
struct iq4_xs {
|
||||
d: f16,
|
||||
scales_h: f16,
|
||||
scales_l: u32,
|
||||
qs: array<u32, 32>
|
||||
};
|
||||
#enddecl(IQ4_XS_T)
|
||||
#endif
|
||||
|
||||
#decl(IQ23_TABLES)
|
||||
#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES)
|
||||
const kmask_iq2xs : array<u32, 2> = array<u32, 2>(
|
||||
0x08040201u, // 1, 2, 4, 8
|
||||
0x80402010u // 16, 32, 64, 128
|
||||
|
|
@ -211,9 +206,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>(
|
|||
0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c,
|
||||
0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc
|
||||
);
|
||||
#enddecl(IQ23_TABLES)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_XXS_GRID)
|
||||
#ifdef IQ2_XXS_GRID
|
||||
const iq2xxs_grid = array<u32, 512>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808,
|
||||
|
|
@ -280,9 +275,9 @@ const iq2xxs_grid = array<u32, 512>(
|
|||
0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819,
|
||||
0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19
|
||||
);
|
||||
#enddecl(IQ2_XXS_GRID)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_XS_GRID)
|
||||
#ifdef IQ2_XS_GRID
|
||||
const iq2xs_grid = array<u32, 1024>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
|
||||
|
|
@ -413,9 +408,9 @@ const iq2xs_grid = array<u32, 1024>(
|
|||
0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19,
|
||||
0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
|
||||
);
|
||||
#enddecl(IQ2_XS_GRID)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_S_GRID)
|
||||
#ifdef IQ2_S_GRID
|
||||
const iq2s_grid = array<u32, 2048>(
|
||||
0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808,
|
||||
0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808,
|
||||
|
|
@ -674,10 +669,9 @@ const iq2s_grid = array<u32, 2048>(
|
|||
0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b,
|
||||
0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b
|
||||
);
|
||||
#enddecl(IQ2_S_GRID)
|
||||
|
||||
#decl(IQ3_XSS_GRID)
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_XXS_GRID
|
||||
const iq3xxs_grid = array<u32, 256>(
|
||||
0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
|
||||
0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
|
||||
|
|
@ -712,10 +706,9 @@ const iq3xxs_grid = array<u32, 256>(
|
|||
0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
|
||||
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04
|
||||
);
|
||||
#enddecl(IQ3_XSS_GRID)
|
||||
|
||||
#decl(IQ3_S_GRID)
|
||||
#endif
|
||||
|
||||
#ifdef IQ3_S_GRID
|
||||
const iq3s_grid = array<u32, 512>(
|
||||
0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305,
|
||||
0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905,
|
||||
|
|
@ -782,9 +775,9 @@ const iq3s_grid = array<u32, 512>(
|
|||
0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b,
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101
|
||||
);
|
||||
#enddecl(IQ3_S_GRID)
|
||||
#endif
|
||||
|
||||
#decl(IQ1_GRID)
|
||||
#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID)
|
||||
|
||||
const IQ1_DELTA: f32 = 0.125;
|
||||
|
||||
|
|
@ -919,12 +912,12 @@ const iq1_grid = array<u32, 1024>(
|
|||
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
|
||||
);
|
||||
|
||||
#enddecl(IQ1_GRID)
|
||||
#endif
|
||||
|
||||
#decl(IQ4_GRID)
|
||||
#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID)
|
||||
|
||||
const kvalues_iq4nl = array<i32, 16>(
|
||||
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113
|
||||
);
|
||||
|
||||
#enddecl(IQ4_GRID)
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -56,12 +56,46 @@ def expand_includes(shader, input_dir):
|
|||
return include_pattern.sub(replacer, shader)
|
||||
|
||||
|
||||
def write_shader(shader_name, shader_code, output_dir, outfile):
|
||||
def chunk_shader(shader_code, max_chunk_len=60000):
|
||||
"""Split shader_code into safe raw-string sized chunks."""
|
||||
return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)]
|
||||
|
||||
|
||||
def raw_delim(shader_code):
|
||||
"""Pick a raw-string delimiter that does not appear in the shader."""
|
||||
delim = "wgsl"
|
||||
while f"){delim}\"" in shader_code:
|
||||
delim += "_x"
|
||||
return delim
|
||||
|
||||
|
||||
def write_shader(shader_name, shader_code, output_dir, outfile, input_dir):
|
||||
shader_code = expand_includes(shader_code, input_dir)
|
||||
|
||||
if output_dir:
|
||||
wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
|
||||
with open(wgsl_filename, "w", encoding="utf-8") as f_out:
|
||||
f_out.write(shader_code)
|
||||
outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
|
||||
|
||||
delim = raw_delim(shader_code)
|
||||
chunks = chunk_shader(shader_code)
|
||||
|
||||
if len(chunks) == 1:
|
||||
outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
|
||||
else:
|
||||
for idx, chunk in enumerate(chunks):
|
||||
outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n')
|
||||
outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n')
|
||||
outfile.write(' static const std::string s = []{\n')
|
||||
outfile.write(' std::string tmp;\n')
|
||||
outfile.write(f' tmp.reserve({len(shader_code)});\n')
|
||||
for idx in range(len(chunks)):
|
||||
outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n')
|
||||
outfile.write(' return tmp;\n')
|
||||
outfile.write(' }();\n')
|
||||
outfile.write(' return s;\n')
|
||||
outfile.write('}\n')
|
||||
outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n')
|
||||
|
||||
|
||||
def generate_variants(fname, input_dir, output_dir, outfile):
|
||||
|
|
@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
|||
try:
|
||||
variants = ast.literal_eval(extract_block(text, "VARIANTS"))
|
||||
except ValueError:
|
||||
write_shader(shader_base_name, text, output_dir, outfile)
|
||||
write_shader(shader_base_name, text, output_dir, outfile, input_dir)
|
||||
else:
|
||||
try:
|
||||
decls_map = parse_decls(extract_block(text, "DECLS"))
|
||||
|
|
@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile):
|
|||
output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
|
||||
else:
|
||||
output_name = shader_base_name
|
||||
write_shader(output_name, final_shader, output_dir, outfile)
|
||||
write_shader(output_name, final_shader, output_dir, outfile, input_dir)
|
||||
|
||||
|
||||
def main():
|
||||
|
|
@ -137,7 +171,8 @@ def main():
|
|||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
with open(args.output_file, "w", encoding="utf-8") as out:
|
||||
out.write("// Auto-generated shader embedding\n\n")
|
||||
out.write("// Auto-generated shader embedding\n")
|
||||
out.write("#include <string>\n\n")
|
||||
for fname in sorted(os.listdir(args.input_dir)):
|
||||
if fname.endswith(".wgsl"):
|
||||
generate_variants(fname, args.input_dir, args.output_dir, out)
|
||||
|
|
|
|||
|
|
@ -1,222 +1,31 @@
|
|||
#define(VARIANTS)
|
||||
enable f16;
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_vec",
|
||||
"REPLS": {
|
||||
"TYPE" : "vec4<f32>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"BLOCK_SIZE": 4
|
||||
},
|
||||
"DECLS": ["F32_VEC"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["F32"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["F16"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "i32",
|
||||
"DST_TYPE": "i32",
|
||||
"BLOCK_SIZE": 1
|
||||
},
|
||||
"DECLS": ["I32"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_1",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_1",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q8_0",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q2_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q3_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q4_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q5_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "q6_k",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "iq2_xxs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "iq2_xs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq2_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq3_xxs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq3_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq1_s",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq1_m",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq4_nl",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE": "iq4_xs",
|
||||
"DST_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(F32_VEC)
|
||||
#ifdef F32_VEC
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset];
|
||||
}
|
||||
#enddecl(F32_VEC)
|
||||
#endif
|
||||
|
||||
#decl(F32)
|
||||
#ifdef F32
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = src[src_base + offset];
|
||||
}
|
||||
#enddecl(F32)
|
||||
#endif
|
||||
|
||||
#decl(F16)
|
||||
#ifdef F16
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = f32(src[src_base + offset]);
|
||||
}
|
||||
#enddecl(F16)
|
||||
#endif
|
||||
|
||||
#decl(I32)
|
||||
#ifdef I32
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
dst[dst_base + offset] = src[src_base + offset];
|
||||
}
|
||||
#enddecl(I32)
|
||||
#endif
|
||||
|
||||
#decl(Q4_0)
|
||||
#ifdef Q4_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_0 = src[src_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
|
|
@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_0)
|
||||
#endif
|
||||
|
||||
#decl(Q4_1)
|
||||
#ifdef Q4_1
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q4_1 = src[src_base + offset];
|
||||
let d = f32(block_q4_1.d);
|
||||
|
|
@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_1)
|
||||
#endif
|
||||
|
||||
#decl(Q5_0)
|
||||
#ifdef Q5_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_0 = src[src_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
|
|
@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q5_0)
|
||||
|
||||
#decl(Q5_1)
|
||||
#ifdef Q5_1
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q5_1 = src[src_base + offset];
|
||||
let d = f32(block_q5_1.d);
|
||||
|
|
@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q5_1)
|
||||
#endif
|
||||
|
||||
#decl(Q8_0)
|
||||
#ifdef Q8_0
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block_q8_0 = src[src_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
|
|
@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q8_0)
|
||||
#endif
|
||||
|
||||
#decl(Q2_K)
|
||||
#ifdef Q2_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q2_K)
|
||||
#endif
|
||||
|
||||
#decl(Q3_K)
|
||||
#ifdef Q3_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q3_K)
|
||||
#endif
|
||||
|
||||
#decl(Q4_K)
|
||||
#ifdef Q4_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
|
|
@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q4_K)
|
||||
#endif
|
||||
|
||||
#decl(Q5_K)
|
||||
#ifdef Q5_K
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(Q5_K)
|
||||
#endif
|
||||
|
||||
#decl(Q6_K)
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
|
|
@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
sc_b_idx += 8;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q6_K)
|
||||
|
||||
#decl(IQ2_XXS)
|
||||
#ifdef IQ2_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ2_XXS)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_XS)
|
||||
#ifdef IQ2_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ2_XS)
|
||||
#endif
|
||||
|
||||
#decl(IQ2_S)
|
||||
#ifdef IQ2_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ2_S)
|
||||
|
||||
#decl(IQ3_XSS)
|
||||
#ifdef IQ3_XXS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ3_XSS)
|
||||
#endif
|
||||
|
||||
#decl(IQ3_S)
|
||||
#ifdef IQ3_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#enddecl(IQ3_S)
|
||||
#endif
|
||||
|
||||
#decl(IQ1_S)
|
||||
#ifdef IQ1_S
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ1_S)
|
||||
|
||||
#decl(IQ1_M)
|
||||
#ifdef IQ1_M
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
|
||||
|
|
@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ1_M)
|
||||
|
||||
#decl(IQ4_NL)
|
||||
#ifdef IQ4_NL
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
dst_i++;
|
||||
}
|
||||
}
|
||||
#enddecl(IQ4_NL)
|
||||
#endif
|
||||
|
||||
#decl(IQ4_XS)
|
||||
#ifdef IQ4_XS
|
||||
fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
||||
let block = src[src_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) {
|
|||
dst_i += 16;
|
||||
}
|
||||
}
|
||||
#enddecl(IQ4_XS)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
DECLS
|
||||
#endif
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<{{TYPE}}>;
|
||||
var<storage, read_write> src: array<SRC_TYPE>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> idx: array<i32>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{DST_TYPE}}>;
|
||||
var<storage, read_write> dst: array<DST_TYPE>;
|
||||
|
||||
struct Params {
|
||||
offset_src: u32, // in elements
|
||||
|
|
@ -842,8 +638,7 @@ struct Params {
|
|||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
|
||||
return;
|
||||
|
|
@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
|
||||
let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
|
||||
|
||||
for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) {
|
||||
for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
|
||||
copy_elements(i_src_row, i_dst_row, i);
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -1,195 +1,24 @@
|
|||
#define(VARIANTS)
|
||||
enable f16;
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"BLOCK_SIZE" : 1
|
||||
},
|
||||
"DECLS" : ["FLOAT"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"BLOCK_SIZE" : 1
|
||||
},
|
||||
"DECLS" : ["FLOAT"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"BLOCK_SIZE" : 1
|
||||
},
|
||||
"DECLS" : ["FLOAT"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q4_0",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q4_1",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q5_0",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q5_1",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q8_0",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q2_k",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q3_k",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q4_k",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q5_k",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "q6_k",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq2_xxs",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq2_xs",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq2_s",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq3_xxs",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq3_s",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq1_s",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq1_m",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq4_nl",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 32,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"]
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"SRC0_TYPE": "iq4_xs",
|
||||
"SRC1_TYPE": "f32",
|
||||
"BLOCK_SIZE": 256,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"]
|
||||
}
|
||||
]
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#end(VARIANTS)
|
||||
#ifdef FLOAT
|
||||
const BLOCK_SIZE = 1u;
|
||||
|
||||
#define(DECLS)
|
||||
#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL)
|
||||
const BLOCK_SIZE = 32u;
|
||||
|
||||
#decl(FLOAT)
|
||||
#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS)
|
||||
const BLOCK_SIZE = 256u;
|
||||
#endif
|
||||
|
||||
#ifdef FLOAT
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]);
|
||||
}
|
||||
#enddecl(FLOAT)
|
||||
#endif
|
||||
|
||||
#decl(Q4_0)
|
||||
#ifdef Q4_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_0.d);
|
||||
|
|
@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q4_0)
|
||||
#endif
|
||||
|
||||
#decl(Q4_1)
|
||||
#ifdef Q4_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q4_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q4_1.d);
|
||||
|
|
@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q4_1)
|
||||
#endif
|
||||
|
||||
#decl(Q5_0)
|
||||
#ifdef Q5_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_0.d);
|
||||
|
|
@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q5_0)
|
||||
#endif
|
||||
|
||||
#decl(Q5_1)
|
||||
#ifdef Q5_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q5_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q5_1.d);
|
||||
|
|
@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q5_1)
|
||||
#endif
|
||||
|
||||
#decl(Q8_0)
|
||||
#ifdef Q8_0
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_0 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_0.d);
|
||||
|
|
@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q8_0)
|
||||
#endif
|
||||
|
||||
#decl(Q8_1)
|
||||
#ifdef Q8_1
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block_q8_1 = src0[src0_idx_base + offset];
|
||||
let d = f32(block_q8_1.d);
|
||||
|
|
@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(Q8_1)
|
||||
#endif
|
||||
|
||||
#decl(Q2_K)
|
||||
#ifdef Q2_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
|
@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q2_K)
|
||||
|
||||
#decl(Q3_K)
|
||||
#ifdef Q3_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
|
@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q3_K)
|
||||
|
||||
#decl(Q4_K)
|
||||
#ifdef Q4_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
|
@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q4_K)
|
||||
|
||||
#decl(Q5_K)
|
||||
#ifdef Q5_K
|
||||
// 8 blocks of 32 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
|
@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q5_K)
|
||||
|
||||
#decl(Q6_K)
|
||||
#ifdef Q6_K
|
||||
// 16 blocks of 16 elements each
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
|
@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(Q6_K)
|
||||
|
||||
#decl(IQ2_XXS)
|
||||
#ifdef IQ2_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ2_XXS)
|
||||
|
||||
#decl(IQ2_XS)
|
||||
#ifdef IQ2_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ2_XS)
|
||||
|
||||
#decl(IQ2_S)
|
||||
#ifdef IQ2_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#enddecl(IQ2_S)
|
||||
|
||||
#decl(IQ3_XSS)
|
||||
#ifdef IQ3_XXS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ3_XSS)
|
||||
|
||||
#decl(IQ3_S)
|
||||
#ifdef IQ3_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#enddecl(IQ3_S)
|
||||
#endif
|
||||
|
||||
#decl(IQ1_S)
|
||||
#ifdef IQ1_S
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ1_S)
|
||||
|
||||
#decl(IQ1_M)
|
||||
#ifdef IQ1_M
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
|
||||
|
|
@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ1_M)
|
||||
|
||||
#decl(IQ4_NL)
|
||||
#ifdef IQ4_NL
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(IQ4_NL)
|
||||
|
||||
#decl(IQ4_XS)
|
||||
#ifdef IQ4_XS
|
||||
fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
||||
let block = src0[src0_idx_base + offset];
|
||||
let d = f32(block.d);
|
||||
|
|
@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 {
|
|||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
#enddecl(IQ4_XS)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
DECLS
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32, // in elements/blocks
|
||||
|
|
@ -864,8 +672,8 @@ struct MulMatParams {
|
|||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
|
@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
|||
let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11;
|
||||
|
||||
var sum = 0.0;
|
||||
for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) {
|
||||
for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) {
|
||||
sum += multiply_add(src0_idx_base, src1_idx_base, i);
|
||||
}
|
||||
dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum;
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -1,58 +1,65 @@
|
|||
#decl(SHMEM_VEC)
|
||||
#ifdef VEC
|
||||
#define VEC_SIZE 4
|
||||
#define SHMEM_TYPE vec4<f16>
|
||||
#define DST_TYPE vec4<f32>
|
||||
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
fn store_shmem(val: vec4<f16>, idx: u32) {
|
||||
shmem[idx] = val.x;
|
||||
shmem[idx + 1] = val.y;
|
||||
shmem[idx + 2] = val.z;
|
||||
shmem[idx + 3] = val.w;
|
||||
}
|
||||
#enddecl(SHMEM_VEC)
|
||||
#endif
|
||||
|
||||
#ifdef SCALAR
|
||||
#define VEC_SIZE 1
|
||||
#define SHMEM_TYPE f16
|
||||
#define DST_TYPE f32
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
#decl(SHMEM_SCALAR)
|
||||
fn store_shmem(val: f16, idx: u32) {
|
||||
shmem[idx] = val;
|
||||
}
|
||||
#enddecl(SHMEM_SCALAR)
|
||||
|
||||
#decl(INIT_SRC0_SHMEM_FLOAT)
|
||||
#endif
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_FLOAT
|
||||
fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
let tile_m = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_m = offset_m + tile_m;
|
||||
let global_k = k_outer + tile_k;
|
||||
let src0_idx = batch_offset + global_m * params.stride_01 + global_k;
|
||||
let src0_val = select( // taking a slight performance hit to avoid oob
|
||||
{{SRC0_TYPE}}(0.0),
|
||||
src0[src0_idx/{{VEC_SIZE}}],
|
||||
SRC0_TYPE(0.0),
|
||||
src0[src0_idx/VEC_SIZE],
|
||||
global_m < params.m && global_k < params.k);
|
||||
store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx);
|
||||
store_shmem(SHMEM_TYPE(src0_val), elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(INIT_SRC0_SHMEM_FLOAT)
|
||||
|
||||
#decl(INIT_SRC1_SHMEM)
|
||||
|
||||
#ifdef INIT_SRC1_SHMEM_FLOAT
|
||||
fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) {
|
||||
for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
let tile_n = elem_idx / TILE_K;
|
||||
let tile_k = elem_idx % TILE_K;
|
||||
let global_n = offset_n + tile_n;
|
||||
let global_k = k_outer + tile_k;
|
||||
let src1_idx = batch_offset + global_n * params.stride_11 + global_k;
|
||||
let src1_val = select(
|
||||
{{SRC1_TYPE}}(0.0),
|
||||
src1[src1_idx/{{VEC_SIZE}}],
|
||||
SRC1_TYPE(0.0),
|
||||
src1[src1_idx/VEC_SIZE],
|
||||
global_n < params.n && global_k < params.k);
|
||||
store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(INIT_SRC1_SHMEM)
|
||||
|
||||
#decl(INIT_SRC0_SHMEM_Q4_0)
|
||||
|
||||
#ifdef INIT_SRC0_SHMEM_Q4_0
|
||||
const BLOCK_SIZE = 32u;
|
||||
// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types.
|
||||
override BLOCKS_K = TILE_K/BLOCK_SIZE;
|
||||
|
|
@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
#enddecl(INIT_SRC0_SHMEM_Q4_0)
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -1,115 +1,19 @@
|
|||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
}
|
||||
]
|
||||
enable f16;
|
||||
|
||||
#end(VARIANTS)
|
||||
#include "common_decls.tmpl"
|
||||
#include "mul_mat_decls.tmpl"
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(VEC)
|
||||
#ifdef VEC
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> {
|
||||
return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn]));
|
||||
}
|
||||
#enddecl(VEC)
|
||||
#endif
|
||||
|
||||
#decl(SCALAR)
|
||||
#ifdef SCALAR
|
||||
fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 {
|
||||
return f32(acc[tm][tn]);
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
enable f16;
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
|
|
@ -130,14 +34,12 @@ struct MulMatParams {
|
|||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
DECLS
|
||||
|
||||
fn get_local_n(thread_id: u32) -> u32 {
|
||||
return thread_id / WORKGROUP_SIZE_M;
|
||||
}
|
||||
|
|
@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 {
|
|||
return thread_id % WORKGROUP_SIZE_M;
|
||||
}
|
||||
|
||||
// TILE_M must be multiple of 4 for vec4 loads
|
||||
const TILE_M = {{WEBGPU_TILE_M}}u;
|
||||
const TILE_N = {{WEBGPU_TILE_N}}u;
|
||||
|
||||
override WORKGROUP_SIZE_M: u32;
|
||||
override WORKGROUP_SIZE_N: u32;
|
||||
override TILE_K: u32;
|
||||
|
||||
override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
|
||||
const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N;
|
||||
const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M;
|
||||
const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N;
|
||||
var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>;
|
||||
|
||||
@compute @workgroup_size(TOTAL_WORKGROUP_SIZE)
|
||||
|
|
@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
for (var tn = 0u; tn < TILE_N; tn++) {
|
||||
let global_col = output_col_base + tn;
|
||||
if (global_col < params.n) {
|
||||
for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) {
|
||||
for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) {
|
||||
let global_row = output_row_base + tm;
|
||||
if (global_row < params.m) {
|
||||
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
|
||||
dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm);
|
||||
dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -1,100 +1,12 @@
|
|||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE" : "vec4<f32>",
|
||||
"SHMEM_TYPE" : "vec4<f16>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE" : "f32",
|
||||
"SHMEM_TYPE" : "f16",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"]
|
||||
}
|
||||
]
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
|
||||
#end(VARIANTS)
|
||||
#include "common_decls.tmpl"
|
||||
#include "mul_mat_decls.tmpl"
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(VEC)
|
||||
#ifdef VEC
|
||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||
dst[dst_idx] = vec4<f32>(
|
||||
f32(shmem[shmem_idx]),
|
||||
|
|
@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
|||
f32(shmem[shmem_idx + 3])
|
||||
);
|
||||
}
|
||||
#enddecl(VEC)
|
||||
#endif
|
||||
|
||||
#decl(SCALAR)
|
||||
#ifdef SCALAR
|
||||
fn store_dst(shmem_idx: u32, dst_idx: u32) {
|
||||
dst[dst_idx] = f32(shmem[shmem_idx]);
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
diagnostic(off, chromium.subgroup_matrix_uniformity);
|
||||
enable f16;
|
||||
enable subgroups;
|
||||
enable chromium_experimental_subgroup_matrix;
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
|
|
@ -138,36 +42,19 @@ struct MulMatParams {
|
|||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed)
|
||||
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
DECLS
|
||||
|
||||
// Note: These are string interpolated at build time, cannot use override constants due to limitations in
|
||||
// current Dawn version type definitions/matrix load requirements for constant memory sizes.
|
||||
const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u;
|
||||
const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u;
|
||||
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
|
||||
// runtime subgroup size is smaller.
|
||||
const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u;
|
||||
|
||||
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
|
||||
|
||||
const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u;
|
||||
const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u;
|
||||
const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u;
|
||||
|
||||
const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u;
|
||||
const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u;
|
||||
|
||||
const TILE_K = {{WEBGPU_TILE_K}}u;
|
||||
|
||||
const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the
|
||||
// runtime subgroup size is smaller.
|
||||
const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N;
|
||||
const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE;
|
||||
const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
|
@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE;
|
||||
let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE;
|
||||
|
||||
for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) {
|
||||
let local_row = idx % WG_TILE_STRIDE;
|
||||
let local_col = idx / WG_TILE_STRIDE;
|
||||
|
||||
|
|
@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
|
|||
|
||||
if (global_col < params.n && global_row < params.m) {
|
||||
let dst_idx = dst_batch_offset + global_col * params.m + global_row;
|
||||
store_dst(idx, dst_idx/{{VEC_SIZE}});
|
||||
store_dst(idx, dst_idx/VEC_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
|
|
@ -1,84 +1,17 @@
|
|||
#define(VARIANTS)
|
||||
[
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f32>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f32_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f32",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f32>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16_vec",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "vec4<f16>",
|
||||
"SRC1_TYPE" : "vec4<f16>",
|
||||
"DST_TYPE": "vec4<f32>",
|
||||
"VEC_SIZE" : 4,
|
||||
},
|
||||
"DECLS": ["VEC", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "f16_f16",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f16",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["SCALAR", "MUL_ACC_FLOAT"]
|
||||
},
|
||||
{
|
||||
"SHADER_SUFFIX": "q4_0_f32",
|
||||
"REPLS": {
|
||||
"SRC0_TYPE" : "f16",
|
||||
"SRC1_TYPE" : "f32",
|
||||
"DST_TYPE": "f32",
|
||||
"VEC_SIZE" : 1,
|
||||
},
|
||||
"DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
enable f16;
|
||||
|
||||
#define(DECLS)
|
||||
#include "common_decls.tmpl"
|
||||
|
||||
#decl(VEC)
|
||||
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
|
||||
return f32(dot({{SRC1_TYPE}}(src0_val), src1_val));
|
||||
#ifdef VEC
|
||||
|
||||
#define VEC_SIZE 4
|
||||
#define DST_TYPE vec4<f32>
|
||||
#define SRC0_TYPE vec4<SRC0_INNER_TYPE>
|
||||
#define SRC1_TYPE vec4<SRC1_INNER_TYPE>
|
||||
|
||||
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
|
||||
return f32(dot(SRC1_TYPE(src0_val), src1_val));
|
||||
}
|
||||
|
||||
fn store_val(group_base: u32) -> vec4<f32> {
|
||||
|
|
@ -87,33 +20,37 @@ fn store_val(group_base: u32) -> vec4<f32> {
|
|||
partial_sums[group_base + THREADS_PER_OUTPUT * 2],
|
||||
partial_sums[group_base + THREADS_PER_OUTPUT * 3]);
|
||||
}
|
||||
#enddecl(VEC)
|
||||
#endif
|
||||
|
||||
#decl(SCALAR)
|
||||
fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 {
|
||||
#ifdef SCALAR
|
||||
|
||||
#define VEC_SIZE 1
|
||||
#define DST_TYPE f32
|
||||
#define SRC0_TYPE SRC0_INNER_TYPE
|
||||
#define SRC1_TYPE SRC1_INNER_TYPE
|
||||
|
||||
fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 {
|
||||
return f32(src0_val) * f32(src1_val);
|
||||
}
|
||||
|
||||
fn store_val(group_base: u32) -> f32 {
|
||||
return partial_sums[group_base];
|
||||
}
|
||||
#enddecl(SCALAR)
|
||||
|
||||
#decl(MUL_ACC_FLOAT)
|
||||
#endif
|
||||
|
||||
#ifdef MUL_ACC_FLOAT
|
||||
fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
||||
var local_sum = 0.0;
|
||||
for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) {
|
||||
let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}];
|
||||
let b = shared_vector[i / {{VEC_SIZE}}];
|
||||
for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) {
|
||||
let a = src0[(idx_base + k_outer + i) / VEC_SIZE];
|
||||
let b = shared_vector[i / VEC_SIZE];
|
||||
local_sum += inner_dot(a, b);
|
||||
}
|
||||
return local_sum;
|
||||
}
|
||||
#endif
|
||||
|
||||
#enddecl(MUL_ACC_FLOAT)
|
||||
|
||||
#decl(MUL_ACC_Q4_0)
|
||||
#ifdef MUL_ACC_Q4_0
|
||||
|
||||
const BLOCK_SIZE = 32;
|
||||
const NQ = 16u; // number of weights per thread
|
||||
|
|
@ -145,15 +82,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 {
|
|||
}
|
||||
return local_sum;
|
||||
}
|
||||
|
||||
#enddecl(MUL_ACC_Q4_0)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
enable f16;
|
||||
|
||||
DECLS
|
||||
#endif
|
||||
|
||||
struct MulMatParams {
|
||||
offset_src0: u32,
|
||||
|
|
@ -174,22 +103,20 @@ struct MulMatParams {
|
|||
broadcast3: u32
|
||||
};
|
||||
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K)
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed)
|
||||
// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included
|
||||
@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns
|
||||
@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed)
|
||||
@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed)
|
||||
|
||||
@group(0) @binding(3) var<uniform> params: MulMatParams;
|
||||
|
||||
override WORKGROUP_SIZE: u32;
|
||||
override TILE_K: u32;
|
||||
override OUTPUTS_PER_WG: u32;
|
||||
override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG;
|
||||
const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG;
|
||||
|
||||
// Shared memory for collaborative loading and reduction
|
||||
var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile
|
||||
var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction
|
||||
var<workgroup> shared_vector: array<SRC1_TYPE, TILE_K/VEC_SIZE>; // Cache vector tile
|
||||
var<workgroup> partial_sums: array<f32, WG_SIZE>; // For reduction
|
||||
|
||||
@compute @workgroup_size(WORKGROUP_SIZE)
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(
|
||||
@builtin(local_invocation_id) local_id: vec3<u32>,
|
||||
@builtin(workgroup_id) wg_id: vec3<u32>,
|
||||
|
|
@ -232,8 +159,8 @@ fn main(
|
|||
let tile_size = min(TILE_K, params.k - k_tile);
|
||||
|
||||
// Cooperatively load vector tile into shared memory (all threads)
|
||||
for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) {
|
||||
shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}];
|
||||
for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) {
|
||||
shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE];
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
|
@ -250,7 +177,7 @@ fn main(
|
|||
workgroupBarrier();
|
||||
let group_base = thread_group * THREADS_PER_OUTPUT;
|
||||
let thread_base = group_base + thread_in_group;
|
||||
var offset = THREADS_PER_OUTPUT / 2;
|
||||
var offset: u32 = THREADS_PER_OUTPUT / 2;
|
||||
while (offset > 0) {
|
||||
if (thread_in_group < offset) {
|
||||
partial_sums[thread_base] += partial_sums[thread_base + offset];
|
||||
|
|
@ -260,8 +187,8 @@ fn main(
|
|||
}
|
||||
|
||||
// Store back to global memory
|
||||
if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) {
|
||||
dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base);
|
||||
if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) {
|
||||
dst[dst_idx / VEC_SIZE] = store_val(group_base);
|
||||
}
|
||||
}
|
||||
#end(SHADER)
|
||||
|
||||
|
|
@ -1,21 +1,11 @@
|
|||
#define(VARIANTS)
|
||||
#ifdef INPLACE
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
[
|
||||
{
|
||||
"SHADER_NAME": "scale_f32",
|
||||
"DECLS": ["NOT_INPLACE"]
|
||||
},
|
||||
{
|
||||
"SHADER_NAME": "scale_f32_inplace",
|
||||
"DECLS": ["INPLACE"]
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(DECLS)
|
||||
|
||||
#decl(NOT_INPLACE)
|
||||
fn store_scale(val: f32, offset: u32) {
|
||||
src[offset] = val;
|
||||
}
|
||||
#else
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> dst: array<f32>;
|
||||
|
||||
|
|
@ -25,20 +15,7 @@ var<uniform> params: Params;
|
|||
fn store_scale(val: f32, offset: u32) {
|
||||
dst[offset] = val;
|
||||
}
|
||||
#enddecl(NOT_INPLACE)
|
||||
|
||||
#decl(INPLACE)
|
||||
@group(0) @binding(1)
|
||||
var<uniform> params: Params;
|
||||
|
||||
fn store_scale(val: f32, offset: u32) {
|
||||
src[offset] = val;
|
||||
}
|
||||
#enddecl(INPLACE)
|
||||
|
||||
#end(DECLS)
|
||||
|
||||
#define(SHADER)
|
||||
#endif
|
||||
|
||||
struct Params {
|
||||
offset_src: u32,
|
||||
|
|
@ -65,10 +42,7 @@ struct Params {
|
|||
@group(0) @binding(0)
|
||||
var<storage, read_write> src: array<f32>;
|
||||
|
||||
DECLS
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
@compute @workgroup_size(WG_SIZE)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x >= params.ne) {
|
||||
return;
|
||||
|
|
@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|||
|
||||
store_scale(src[i_src] * params.scale + params.bias, i_dst);
|
||||
}
|
||||
#end(SHADER)
|
||||
Loading…
Reference in a new issue