Closed Bug 1877734 Opened 2 years ago Closed 1 month ago

Investigate MatMul performance in ONNX Runtime

Categories

(Core :: Machine Learning: General, task)

Tracking

RESOLVED WONTFIX

People

(Reporter: tarek, Unassigned)

References

Details

(Whiteboard: [genai])

For comparison, https://gregtatum.github.io/taskcluster-tools/benchmark.html takes 6.67 seconds on my machine on Release, while it takes 2.49 seconds on Nightly.

Blocks: 1883591
Group: mozilla-employee-confidential
Whiteboard: [genai]

We have discussed this a bit during mozweek.

ONNX Runtime has various SIMD backends (https://github.com/microsoft/onnxruntime/tree/44dcc3aafd4e308a1847552335ad85db1a1ec5e7/onnxruntime/core/mlas/lib); one of them is built for WASM (https://github.com/microsoft/onnxruntime/tree/44dcc3aafd4e308a1847552335ad85db1a1ec5e7/onnxruntime/core/mlas/lib/wasm_simd). The WASM one can only use a subset of SIMD instructions, so it can't be as fast as the native one.

We have two options:

  1. Keep ONNX in WASM, patch the WASM to use our native MatMul functions instead of the WASM MatMul functions. This is the approach we are taking for Bergamot (see https://github.com/browsermt/bergamot-translator/blob/main/wasm/patch-artifacts-import-gemm-module.sh and https://github.com/browsermt/bergamot-translator/blob/main/wasm/import-gemm-module.js and https://searchfox.org/mozilla-central/source/js/src/intgemm).

  2. Compile ONNX to native instead of WASM. This means we'll get the MatMul for free, since it is already in native ONNX. The main con of this approach is that we'd need to patch transformer.js to have it talk to a native ONNX instead of a WASM ONNX, which would be a maintenance burden (unless we work with transformer.js upstream to have this Firefox layer live in the main repo, so we don't need to maintain a fork).
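The import-selection logic behind option 1 can be sketched in plain JavaScript. This is a hedged sketch, not the actual Bergamot patch: the `mozIntGemm` builtin name comes from the patch quoted later in this bug, while the import names, the placeholder function bodies, and this `fallbackGemm` are illustrative stand-ins.

```javascript
// Sketch of option 1 (Bergamot-style GEMM import patching).
// At instantiation time, the host picks the GEMM implementation:
// the optimized native-backed builtin if the browser exposes it
// (Firefox's WebAssembly.mozIntGemm), otherwise plain fallbacks.
// The function bodies below are placeholders, not real GEMM kernels.

const GEMM_TO_FALLBACK_FUNCTIONS_MAP = {
  // import name -> plain JS stand-in (illustrative names only)
  int8PrepareA: (...args) => { /* prepare/quantize matrix A */ },
  int8MultiplyAndAddBias: (...args) => { /* C = A * B + bias */ },
};

// The fallback path simply hands the plain implementations back
// as the import object for the main module.
function fallbackGemm(functionsMap) {
  return { ...functionsMap };
}

function selectGemmImports() {
  const optimizedGemmModule = WebAssembly["mozIntGemm"]; // Firefox-only builtin
  if (!optimizedGemmModule) {
    return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
  }
  // In Firefox, instantiate the builtin module and use its exports to
  // satisfy the main module's GEMM imports instead of the fallbacks.
  return new WebAssembly.Instance(optimizedGemmModule()).exports;
}
```

Outside Firefox (e.g. in Node), `WebAssembly.mozIntGemm` is undefined, so `selectGemmImports()` returns the fallback map.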

To get a better idea of how much these optimizations could help, and of the difference in potential speed between the two options above, we could:

  • compare Firefox translations performance with/without the native MatMul;
  • compare ONNX performance between WASM and native.
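For the second comparison, the harness could look something like the sketch below. This is only a scaffold under stated assumptions: it implements a naive JS MatMul as a baseline, and the interesting numbers would come from pointing `bench` at the same workload running through the WASM and native ONNX builds (not shown here).

```javascript
// Hypothetical micro-benchmark scaffold for the MatMul comparison.
// Only the naive JS baseline is implemented; the WASM and native
// ONNX backends would be benchmarked with the same bench() helper.

function naiveMatMul(a, b, n) {
  // a, b: flat row-major n*n Float32Arrays; returns C = A * B.
  const c = new Float32Array(n * n);
  for (let i = 0; i < n; i++) {
    for (let k = 0; k < n; k++) {
      const aik = a[i * n + k];
      for (let j = 0; j < n; j++) {
        c[i * n + j] += aik * b[k * n + j];
      }
    }
  }
  return c;
}

function bench(label, fn, iterations = 5) {
  const start = performance.now();
  for (let i = 0; i < iterations; i++) fn();
  const ms = (performance.now() - start) / iterations;
  console.log(`${label}: ${ms.toFixed(1)} ms/iter`);
}

const n = 128;
const a = Float32Array.from({ length: n * n }, () => Math.random());
const b = Float32Array.from({ length: n * n }, () => Math.random());
bench("naive JS matmul", () => naiveMatMul(a, b, n));
```

As comment #4 below notes, the comparison should also pin down which backend options (int8, f32, f64) each side is using, or the numbers won't be apples to apples.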

(In reply to Marco Castelluccio [:marco] from comment #2)

  • compare Firefox translations performance with/without the native MatMul;

To do this, Greg provided this patch:

diff --git a/toolkit/components/translations/bergamot-translator/bergamot-translator.js b/toolkit/components/translations/bergamot-translator/bergamot-translator.js
index c50eb32dffc12..e93eb521d7451 100644
--- a/toolkit/components/translations/bergamot-translator/bergamot-translator.js
+++ b/toolkit/components/translations/bergamot-translator/bergamot-translator.js
@@ -3475,9 +3475,7 @@ function loadBergamot(Module) {
     const OPTIMIZED_GEMM = "mozIntGemm";
 
     const optimizedGemmModule = WebAssembly[OPTIMIZED_GEMM];
-    if (!optimizedGemmModule) {
-      return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
-    }
+    return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);
 
     const optimizedGemmModuleExports = new WebAssembly.Instance(
       optimizedGemmModule(),

We can build a normal Firefox and a Firefox with this patch and then test the perf difference on Windows (we still don't have MatMul acceleration for Mac, see bug 1868104).

  • return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);

Just to mention, IIRC the fallbackGemm has subpar performance: it goes from Wasm into JS then back to Wasm, hence "fallback".
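As I read that comment, the fallback shape is roughly the following. This is a hedged sketch, not the real bergamot code: `wasmExports` stands in for the main instance's exports, and all names in the map are illustrative.

```javascript
// Sketch of why the fallback is slow: it satisfies the module's GEMM
// imports with JS trampolines that call back into exports of the same
// wasm module, so every GEMM call crosses Wasm -> JS -> Wasm.

function fallbackGemm(nameMap, wasmExports) {
  const gemmImports = {};
  for (const [importName, exportName] of Object.entries(nameMap)) {
    // Each import is a JS function (the extra boundary crossing) that
    // forwards to the wasm-compiled fallback implementation.
    gemmImports[importName] = (...args) => wasmExports[exportName](...args);
  }
  return gemmImports;
}
```

With this shape, removing the indirection (as comment #7 below suggests) would let the module call its own wasm fallback directly, with no JS hop per call.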

To have a better idea how much these optimizations could help

Pay attention to which options and backends (int8, f32, f64, etc.) are used during the comparison too; e.g. clang has auto-vectorization, so int8 may be fast too. Also evaluate different platforms, especially x86_64 and aarch64.

(In reply to Yury Delendik (:yury) from comment #4)

  • return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);

Just to mention, IIRC the fallbackGemm has subpar performance: it goes from Wasm into JS then back to Wasm, hence "fallback".

Do you know why the fallback is doing that and not just using Wasm?

(In reply to Marco Castelluccio [:marco] from comment #5)

(In reply to Yury Delendik (:yury) from comment #4)

  • return fallbackGemm(GEMM_TO_FALLBACK_FUNCTIONS_MAP);

Just to mention, IIRC the fallbackGemm has subpar performance: it goes from Wasm into JS then back to Wasm, hence "fallback".

Do you know why the fallback is doing that and not just using Wasm?

The initial design was to use the inversion of control pattern to decouple matrix multiplication from the main logic, but it was never fully implemented. The fallback wasm module/methods, similar to the mozIntGemm builtin, were supposed to be implemented in a separate module and removed from the main logic, but it looks like that never happened. If there are no plans to use the mozIntGemm builtins, you can remove this fallback interface and use the fallback wasm logic directly.

Assignee: nobody → tziade

This is being implemented in bug 1936320.

Assignee: tarek → nobody

At this point we're planning to consolidate on onnx-native as the better approach. I'm going to wontfix this.

Status: NEW → RESOLVED
Closed: 1 month ago
Resolution: --- → WONTFIX