Skip to content

Commit db460b7

Browse files
committed
wip : WASM 128-bit SIMD support
1 parent e905c6f commit db460b7

5 files changed

Lines changed: 189 additions & 13 deletions

File tree

CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,15 @@ else()
123123
if (MSVC)
124124
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /arch:AVX2 /D_CRT_SECURE_NO_WARNINGS=1")
125125
else()
126-
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
126+
if (EMSCRIPTEN)
127+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
128+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
129+
else()
130+
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx -mavx2 -mfma -mf16c")
131+
endif()
127132
endif()
128133
endif()
129134

130-
if (EMSCRIPTEN)
131-
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -pthread -msimd128")
132-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
133-
endif()
134-
135135
# whisper - this is the main library of the project
136136

137137
set(TARGET whisper)

bindings/javascript/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ if (WHISPER_WASM_SINGLE_FILE)
2121
)
2222
endif()
2323

24+
#-s TOTAL_MEMORY=536870912 \
2425
set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \
2526
--bind \
2627
-s MODULARIZE=1 \
2728
-s ASSERTIONS=1 \
2829
-s USE_PTHREADS=1 \
29-
-s PTHREAD_POOL_SIZE=8 \
30-
-s TOTAL_MEMORY=536870912 \
30+
-s PTHREAD_POOL_SIZE=9 \
31+
-s ALLOW_MEMORY_GROWTH=1 \
3132
-s FORCE_FILESYSTEM=1 \
3233
-s EXPORT_NAME=\"'whisper_factory'\" \
3334
${EXTRA_FLAGS} \

bindings/javascript/emscripten.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <emscripten/bind.h>
55

66
#include <vector>
7+
#include <thread>
78

89
std::vector<struct whisper_context *> g_contexts(4, nullptr);
910

@@ -47,7 +48,7 @@ EMSCRIPTEN_BINDINGS(whisper) {
4748
params.print_special_tokens = false;
4849
params.translate = false;
4950
params.language = "en";
50-
params.n_threads = 4;
51+
params.n_threads = std::min(8, (int) std::thread::hardware_concurrency());
5152
params.offset_ms = 0;
5253

5354
std::vector<float> pcmf32;

bindings/javascript/whisper.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ggml.c

Lines changed: 177 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,11 @@ ggml_fp16_t ggml_fp32_to_fp16(float x) {
7373

7474
#else
7575

76+
#ifdef __wasm_simd128__
77+
#include <wasm_simd128.h>
78+
#else
7679
#include <immintrin.h>
80+
#endif
7781

7882
// FP16 <-> FP32
7983
// ref: https://github.com/Maratyszcza/FP16
@@ -288,7 +292,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
288292
sumf += x[i]*y[i];
289293
}
290294
#elif defined(__AVX2__)
291-
// AVX 256-bit (unroll 4)
295+
// AVX 256-bit
292296
const int n32 = (n & ~31);
293297

294298
__m256 sum0 = _mm256_setzero_ps();
@@ -330,6 +334,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
330334
for (int i = n32; i < n; ++i) {
331335
sumf += x[i]*y[i];
332336
}
337+
#elif defined(__wasm_simd128__)
338+
// WASM 128-bit
339+
const int n16 = (n & ~15);
340+
341+
v128_t sum0 = wasm_f32x4_splat(0);
342+
v128_t sum1 = wasm_f32x4_splat(0);
343+
v128_t sum2 = wasm_f32x4_splat(0);
344+
v128_t sum3 = wasm_f32x4_splat(0);
345+
346+
v128_t x0, x1, x2, x3;
347+
v128_t y0, y1, y2, y3;
348+
349+
for (int i = 0; i < n16; i += 16) {
350+
x0 = wasm_v128_load(x + i + 0);
351+
x1 = wasm_v128_load(x + i + 4);
352+
x2 = wasm_v128_load(x + i + 8);
353+
x3 = wasm_v128_load(x + i + 12);
354+
355+
y0 = wasm_v128_load(y + i + 0);
356+
y1 = wasm_v128_load(y + i + 4);
357+
y2 = wasm_v128_load(y + i + 8);
358+
y3 = wasm_v128_load(y + i + 12);
359+
360+
sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
361+
sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
362+
sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
363+
sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
364+
}
365+
366+
sum0 = wasm_f32x4_add(sum0, sum1);
367+
sum2 = wasm_f32x4_add(sum2, sum3);
368+
sum0 = wasm_f32x4_add(sum0, sum2);
369+
370+
sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
371+
372+
// leftovers
373+
for (int i = n16; i < n; ++i) {
374+
sumf += x[i]*y[i];
375+
}
333376
#else
334377
// scalar
335378
for (int i = 0; i < n; ++i) {
@@ -446,7 +489,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
446489
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
447490
}
448491
#elif defined(__AVX2__)
449-
// AVX 256-bit (unroll 4)
492+
// AVX 256-bit
450493
const int n32 = (n & ~31);
451494

452495
__m256 sum0 = _mm256_setzero_ps();
@@ -489,6 +532,54 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
489532
//GGML_ASSERT(false);
490533
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
491534
}
535+
#elif defined(__wasm_simd128__)
536+
// WASM 128-bit
537+
const int n16 = (n & ~15);
538+
539+
v128_t sum0 = wasm_f32x4_splat(0.0f);
540+
v128_t sum1 = wasm_f32x4_splat(0.0f);
541+
v128_t sum2 = wasm_f32x4_splat(0.0f);
542+
v128_t sum3 = wasm_f32x4_splat(0.0f);
543+
544+
v128_t x0, x1, x2, x3;
545+
v128_t y0, y1, y2, y3;
546+
547+
float tx[16];
548+
float ty[16];
549+
550+
for (int i = 0; i < n16; i += 16) {
551+
for (int k = 0; k < 16; ++k) {
552+
tx[k] = ggml_fp16_to_fp32(x[i + k]);
553+
ty[k] = ggml_fp16_to_fp32(y[i + k]);
554+
}
555+
556+
x0 = wasm_v128_load(tx + 0);
557+
x1 = wasm_v128_load(tx + 4);
558+
x2 = wasm_v128_load(tx + 8);
559+
x3 = wasm_v128_load(tx + 12);
560+
561+
y0 = wasm_v128_load(ty + 0);
562+
y1 = wasm_v128_load(ty + 4);
563+
y2 = wasm_v128_load(ty + 8);
564+
y3 = wasm_v128_load(ty + 12);
565+
566+
sum0 = wasm_f32x4_add(sum0, wasm_f32x4_mul(x0, y0));
567+
sum1 = wasm_f32x4_add(sum1, wasm_f32x4_mul(x1, y1));
568+
sum2 = wasm_f32x4_add(sum2, wasm_f32x4_mul(x2, y2));
569+
sum3 = wasm_f32x4_add(sum3, wasm_f32x4_mul(x3, y3));
570+
}
571+
572+
sum0 = wasm_f32x4_add(sum0, sum1);
573+
sum2 = wasm_f32x4_add(sum2, sum3);
574+
sum0 = wasm_f32x4_add(sum0, sum2);
575+
576+
sumf = wasm_f32x4_extract_lane(sum0, 0) + wasm_f32x4_extract_lane(sum0, 1) + wasm_f32x4_extract_lane(sum0, 2) + wasm_f32x4_extract_lane(sum0, 3);
577+
578+
// leftovers
579+
for (int i = n16; i < n; ++i) {
580+
//GGML_ASSERT(false);
581+
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
582+
}
492583
#else
493584
for (int i = 0; i < n; ++i) {
494585
sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]);
@@ -535,7 +626,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
535626
y[i] += x[i]*v;
536627
}
537628
#elif defined(__AVX2__)
538-
// AVX 256-bit (unroll 4)
629+
// AVX 256-bit
539630
const int n32 = (n & ~31);
540631

541632
const __m256 v4 = _mm256_set1_ps(v);
@@ -569,6 +660,41 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
569660
for (int i = n32; i < n; ++i) {
570661
y[i] += x[i]*v;
571662
}
663+
#elif defined(__wasm_simd128__)
664+
// WASM SIMD 128-bit
665+
const int n16 = (n & ~15);
666+
667+
const v128_t v4 = wasm_f32x4_splat(v);
668+
669+
v128_t x0, x1, x2, x3;
670+
v128_t y0, y1, y2, y3;
671+
672+
for (int i = 0; i < n16; i += 16) {
673+
x0 = wasm_v128_load(x + i + 0);
674+
x1 = wasm_v128_load(x + i + 4);
675+
x2 = wasm_v128_load(x + i + 8);
676+
x3 = wasm_v128_load(x + i + 12);
677+
678+
y0 = wasm_v128_load(y + i + 0);
679+
y1 = wasm_v128_load(y + i + 4);
680+
y2 = wasm_v128_load(y + i + 8);
681+
y3 = wasm_v128_load(y + i + 12);
682+
683+
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
684+
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
685+
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
686+
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
687+
688+
wasm_v128_store(y + i + 0, y0);
689+
wasm_v128_store(y + i + 4, y1);
690+
wasm_v128_store(y + i + 8, y2);
691+
wasm_v128_store(y + i + 12, y3);
692+
}
693+
694+
// leftovers
695+
for (int i = n16; i < n; ++i) {
696+
y[i] += x[i]*v;
697+
}
572698
#else
573699
// scalar
574700
for (int i = 0; i < n; ++i) {
@@ -696,6 +822,54 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_
696822
GGML_ASSERT(false);
697823
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
698824
}
825+
#elif defined(__wasm_simd128__)
826+
// WASM SIMD 128-bit
827+
const int n16 = (n & ~15);
828+
829+
const v128_t v4 = wasm_f32x4_splat(v);
830+
831+
v128_t x0, x1, x2, x3;
832+
v128_t y0, y1, y2, y3;
833+
834+
float tx[16];
835+
float ty[16];
836+
837+
for (int i = 0; i < n16; i += 16) {
838+
for (int k = 0; k < 16; ++k) {
839+
tx[k] = ggml_fp16_to_fp32(x[i + k]);
840+
ty[k] = ggml_fp16_to_fp32(y[i + k]);
841+
}
842+
843+
x0 = wasm_v128_load(tx + 0);
844+
x1 = wasm_v128_load(tx + 4);
845+
x2 = wasm_v128_load(tx + 8);
846+
x3 = wasm_v128_load(tx + 12);
847+
848+
y0 = wasm_v128_load(ty + 0);
849+
y1 = wasm_v128_load(ty + 4);
850+
y2 = wasm_v128_load(ty + 8);
851+
y3 = wasm_v128_load(ty + 12);
852+
853+
y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4));
854+
y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4));
855+
y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4));
856+
y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4));
857+
858+
wasm_v128_store(ty + 0, y0);
859+
wasm_v128_store(ty + 4, y1);
860+
wasm_v128_store(ty + 8, y2);
861+
wasm_v128_store(ty + 12, y3);
862+
863+
for (int k = 0; k < 16; ++k) {
864+
y[i + k] = ggml_fp32_to_fp16(ty[k]);
865+
}
866+
}
867+
868+
// leftovers
869+
for (int i = n16; i < n; ++i) {
870+
GGML_ASSERT(false);
871+
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);
872+
}
699873
#else
700874
for (int i = 0; i < n; ++i) {
701875
y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v);

0 commit comments

Comments
 (0)