diff --git a/3rdparty/llama.cpp b/3rdparty/llama.cpp
index 1f86f058d..b1b73dae1 160000
--- a/3rdparty/llama.cpp
+++ b/3rdparty/llama.cpp
@@ -1 +1 @@
-Subproject commit 1f86f058de0c3f4098dedae2ae8653c335c868a1
+Subproject commit b1b73dae1184a9c8665dc8f015205d1ee150ea58
diff --git a/README.md b/README.md
index 3bb25596e..ad36ae465 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
[
](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
-Try it out via this [demo](https://demo-bitnet-h0h8hcfqeqhrf5gf.canadacentral-01.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md).
+Try it out via this [demo](https://demo-bitnet-h0h8hcfqeqhrf5gf.canadacentral-01.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/EthanPeng-2048/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/EthanPeng-2048/BitNet/blob/main/gpu/README.md).
bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support will coming next).
@@ -22,14 +22,14 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2:
https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1
## What's New:
-- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/microsoft/BitNet/blob/main/src/README.md) 
-- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md)
+- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/EthanPeng-2048/BitNet/blob/main/src/README.md) 
+- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/EthanPeng-2048/BitNet/blob/main/gpu/README.md)
- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T)
- 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880)
- 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965)
- 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144)
- 10/17/2024 bitnet.cpp 1.0 released.
-- 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
+- 03/21/2024 [The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ](https://github.com/EthanPeng-2048/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf)
- 02/27/2024 [The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)
- 10/17/2023 [BitNet: Scaling 1-bit Transformers for Large Language Models](https://arxiv.org/abs/2310.11453)
@@ -180,7 +180,7 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp)
1. Clone the repo
```bash
-git clone --recursive https://github.com/microsoft/BitNet.git
+git clone --recursive https://github.com/EthanPeng-2048/BitNet.git
cd BitNet
```
2. Install the dependencies
diff --git a/include/bitnet-lut-kernels.h b/include/bitnet-lut-kernels.h
new file mode 100644
index 000000000..6ac0e765c
--- /dev/null
+++ b/include/bitnet-lut-kernels.h
@@ -0,0 +1,1171 @@
+#if defined(GGML_BITNET_X86_TL2)
+#include "ggml-bitnet.h"
+#include
+#include
+#define GGML_BITNET_MAX_NODES 8192
+static bool initialized = false;
+static bitnet_tensor_extra * bitnet_tensor_extras = nullptr;
+static size_t bitnet_tensor_extras_index = 0;
+static void * aligned_malloc(size_t size) {
+#if defined(_WIN32)
+ return _aligned_malloc(size, 64);
+#else
+ void * ptr = nullptr;
+ posix_memalign(&ptr, 64, size);
+ return ptr;
+#endif
+}
+
+static void aligned_free(void * ptr) {
+#if defined(_WIN32)
+ _aligned_free(ptr);
+#else
+ free(ptr);
+#endif
+}
+#define BK2 32
+#if defined __AVX2__
+inline void _mm256_merge_epi32(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)
+{
+ __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));
+ __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));
+ *vl = _mm256_unpacklo_epi32(va, vb);
+ *vh = _mm256_unpackhi_epi32(va, vb);
+}
+inline void _mm256_merge_epi64(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)
+{
+ __m256i va = _mm256_permute4x64_epi64(v0, _MM_SHUFFLE(3, 1, 2, 0));
+ __m256i vb = _mm256_permute4x64_epi64(v1, _MM_SHUFFLE(3, 1, 2, 0));
+ *vl = _mm256_unpacklo_epi64(va, vb);
+ *vh = _mm256_unpackhi_epi64(va, vb);
+}
+inline void _mm256_merge_si128(const __m256i v0, const __m256i v1, __m256i *vl, __m256i *vh)
+{
+ *vl = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 2, 0, 0));
+ *vh = _mm256_permute2x128_si256(v0, v1, _MM_SHUFFLE(0, 3, 0, 1));
+}
+inline void Transpose_8_8(
+ __m256i *v0,
+ __m256i *v1,
+ __m256i *v2,
+ __m256i *v3,
+ __m256i *v4,
+ __m256i *v5,
+ __m256i *v6,
+ __m256i *v7)
+{
+ __m256i w0, w1, w2, w3, w4, w5, w6, w7;
+ __m256i x0, x1, x2, x3, x4, x5, x6, x7;
+ _mm256_merge_epi32(*v0, *v1, &w0, &w1);
+ _mm256_merge_epi32(*v2, *v3, &w2, &w3);
+ _mm256_merge_epi32(*v4, *v5, &w4, &w5);
+ _mm256_merge_epi32(*v6, *v7, &w6, &w7);
+ _mm256_merge_epi64(w0, w2, &x0, &x1);
+ _mm256_merge_epi64(w1, w3, &x2, &x3);
+ _mm256_merge_epi64(w4, w6, &x4, &x5);
+ _mm256_merge_epi64(w5, w7, &x6, &x7);
+ _mm256_merge_si128(x0, x4, v0, v1);
+ _mm256_merge_si128(x1, x5, v2, v3);
+ _mm256_merge_si128(x2, x6, v4, v5);
+ _mm256_merge_si128(x3, x7, v6, v7);
+}
+#endif
+inline int32_t per_tensor_quant(int k, void* lut_scales_, void* b_) {
+ bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+ bitnet_float_type* b = (bitnet_float_type*)b_;
+#if defined __AVX2__
+ __m256 max_vec = _mm256_set1_ps(0.f);
+ const __m256 vec_sign = _mm256_set1_ps(-0.0f);
+ for (int i = 0; i < k / 8; i++) {
+ __m256 vec_b = _mm256_loadu_ps(b + i * 8);
+ __m256 vec_babs = _mm256_andnot_ps(vec_sign, vec_b);
+ max_vec = _mm256_max_ps(vec_babs, max_vec);
+ }
+ __m128 max1 = _mm_max_ps(_mm256_extractf128_ps(max_vec, 1), _mm256_castps256_ps128(max_vec));
+ max1 = _mm_max_ps(max1, _mm_movehl_ps(max1, max1));
+ max1 = _mm_max_ss(max1, _mm_movehdup_ps(max1));
+ float scales = 127 / _mm_cvtss_f32(max1);
+ *lut_scales = scales;
+#endif
+ return 0;
+}
+inline int32_t partial_max_reset(int32_t bs, void* lut_scales_) {
+ bitnet_float_type* lut_scales = (bitnet_float_type*)lut_scales_;
+ #pragma unroll
+ for (int i=0; i< bs; i++) {
+ lut_scales[i] = 0.0;
+ }
+ return 0;
+}
+template
+inline int32_t three_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
+#if defined __AVX2__
+ __m256i vec_lut[16];
+ const __m256i vec_bi = _mm256_set_epi32(84, 72, 60, 48, 36, 24, 12, 0);
+ float scales = *lut_scales;
+ __m256i shuffle_mask = _mm256_set_epi8(
+ 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
+ 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,
+ 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
+ 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00
+ );
+#pragma unroll
+ for (int k = 0; k < act_k / 24; ++k) {
+ __m256 vec_b0 = _mm256_i32gather_ps(b + k * 24 + 0, vec_bi, 1);
+ __m256 vec_b1 = _mm256_i32gather_ps(b + k * 24 + 1, vec_bi, 1);
+ __m256 vec_b2 = _mm256_i32gather_ps(b + k * 24 + 2, vec_bi, 1);
+
+ __m256i vec_b0i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ __m256i vec_b1i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ __m256i vec_b2i = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b2, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+
+ vec_lut[15] = _mm256_setzero_si256();
+ vec_lut[14] = _mm256_setzero_si256();
+ vec_lut[13] = vec_b0i;
+ vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b1i);
+ vec_lut[13] = _mm256_add_epi32(vec_lut[13], vec_b2i);
+ vec_lut[12] = vec_b0i;
+ vec_lut[12] = _mm256_add_epi32(vec_lut[12], vec_b1i);
+ vec_lut[11] = vec_b0i;
+ vec_lut[11] = _mm256_add_epi32(vec_lut[11], vec_b1i);
+ vec_lut[11] = _mm256_sub_epi32(vec_lut[11], vec_b2i);
+ vec_lut[10] = vec_b0i;
+ vec_lut[10] = _mm256_add_epi32(vec_lut[10], vec_b2i);
+ vec_lut[9] = vec_b0i;
+ vec_lut[8] = vec_b0i;
+ vec_lut[8] = _mm256_sub_epi32(vec_lut[8], vec_b2i);
+ vec_lut[7] = vec_b0i;
+ vec_lut[7] = _mm256_sub_epi32(vec_lut[7], vec_b1i);
+ vec_lut[7] = _mm256_add_epi32(vec_lut[7], vec_b2i);
+ vec_lut[6] = vec_b0i;
+ vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1i);
+ vec_lut[5] = vec_b0i;
+ vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b1i);
+ vec_lut[5] = _mm256_sub_epi32(vec_lut[5], vec_b2i);
+ vec_lut[4] = vec_b1i;
+ vec_lut[4] = _mm256_add_epi32(vec_lut[4], vec_b2i);
+ vec_lut[3] = vec_b1i;
+ vec_lut[2] = vec_b1i;
+ vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b2i);
+ vec_lut[1] = vec_b2i;
+ vec_lut[0] = _mm256_setzero_si256();
+ __m256i ix[16];
+
+#pragma unroll
+ for (int g = 0; g < 16; ++g) {
+ ix[g] = vec_lut[g];
+ }
+
+ Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));
+ Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));
+
+#pragma unroll
+ for (int g = 0; g < 8; ++g) {
+ ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);
+ ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
+ ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);
+ ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
+ }
+ int8_t* qlut_i8 = reinterpret_cast(qlut);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);
+
+ }
+
+ *lut_scales = scales;
+#endif
+ return 0;
+}
+
+template
+inline int32_t two_lut_ctor(int8_t* qlut, bitnet_float_type* b, bitnet_float_type* lut_scales) {
+#if defined __AVX2__
+ __m256i vec_lut[16];
+ const __m256i vec_bi = _mm256_set_epi32(56, 48, 40, 32, 24, 16, 8, 0);
+ float scales = *lut_scales;
+ __m256i shuffle_mask = _mm256_set_epi8(
+ 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
+ 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00,
+ 0x0f, 0x0d, 0x0b, 0x09, 0x07, 0x05, 0x03, 0x01,
+ 0x0e, 0x0c, 0x0a, 0x08, 0x06, 0x04, 0x02, 0x00
+ );
+#pragma unroll
+ for (int k = 0; k < act_k / 16; ++k) {
+ __m256 vec_b0f = _mm256_i32gather_ps(b + k * 16 + 0, vec_bi, 1);
+ __m256 vec_b1f = _mm256_i32gather_ps(b + k * 16 + 1, vec_bi, 1);
+
+ __m256i vec_b0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b0f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ __m256i vec_b1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(vec_b1f, _mm256_set1_ps(scales)), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ vec_lut[15] = _mm256_setzero_si256();
+ vec_lut[14] = _mm256_setzero_si256();
+ vec_lut[13] = _mm256_setzero_si256();
+ vec_lut[12] = _mm256_setzero_si256();
+ vec_lut[11] = _mm256_setzero_si256();
+ vec_lut[10] = _mm256_setzero_si256();
+ vec_lut[9] = _mm256_setzero_si256();
+ vec_lut[8] = vec_b0;
+ vec_lut[8] = _mm256_add_epi32(vec_lut[8], vec_b1);
+ vec_lut[7] = vec_b0;
+ vec_lut[6] = vec_b0;
+ vec_lut[6] = _mm256_sub_epi32(vec_lut[6], vec_b1);
+ vec_lut[5] = vec_b1;
+ vec_lut[4] = _mm256_setzero_si256();
+ vec_lut[3] = _mm256_setzero_si256();
+ vec_lut[3] = _mm256_sub_epi32(vec_lut[3], vec_b1);
+ vec_lut[2] = _mm256_setzero_si256();
+ vec_lut[2] = _mm256_sub_epi32(vec_lut[2], vec_b0);
+ vec_lut[2] = _mm256_add_epi32(vec_lut[2], vec_b1);
+ vec_lut[1] = _mm256_setzero_si256();
+ vec_lut[1] = _mm256_sub_epi32(vec_lut[1], vec_b0);
+ vec_lut[0] = _mm256_setzero_si256();
+ vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b0);
+ vec_lut[0] = _mm256_sub_epi32(vec_lut[0], vec_b1);
+
+ __m256i ix[16];
+#pragma unroll
+ for (int g = 0; g < 16; ++g) {
+ ix[g] = vec_lut[g];
+ }
+
+ Transpose_8_8(&(ix[0]), &(ix[1]), &(ix[2]), &(ix[3]), &(ix[4]), &(ix[5]),&(ix[6]), &(ix[7]));
+ Transpose_8_8(&(ix[8]), &(ix[9]), &(ix[10]), &(ix[11]), &(ix[12]), &(ix[13]),&(ix[14]), &(ix[15]));
+
+#pragma unroll
+ for (int g = 0; g < 8; ++g) {
+ ix[g] = _mm256_packs_epi32(ix[g], ix[g + 8]);
+ ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
+ ix[g] = _mm256_shuffle_epi8(ix[g], shuffle_mask);
+ ix[g] = _mm256_permute4x64_epi64(ix[g], _MM_SHUFFLE(3, 1, 2, 0));
+ }
+
+ int8_t* qlut_i8 = reinterpret_cast(qlut);
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 0 * 32 + 0), ix[0]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 1 * 32 + 0), ix[1]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 2 * 32 + 0), ix[2]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 3 * 32 + 0), ix[3]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 4 * 32 + 0), ix[4]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 5 * 32 + 0), ix[5]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 6 * 32 + 0), ix[6]);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(qlut_i8 + k * 256 + 7 * 32 + 0), ix[7]);
+
+ }
+ *lut_scales = scales;
+#endif
+ return 0;
+}
+static bool is_type_supported(enum ggml_type type) {
+ if (type == GGML_TYPE_Q4_0 ||
+ type == GGML_TYPE_TL2) {
+ return true;
+ } else {
+ return false;
+ }
+}
+#include
+
+#define BM3200_8640 160
+#define BBK3200_8640 96
+template
+inline void three_tbl_impl_3200_8640(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);
+ const __m256i vec_zero = _mm256_set1_epi8(0x00);
+ const __m256i vec_one = _mm256_set1_epi8(0xff);
+ const int KK = BBK3200_8640 / 3;
+#pragma unroll
+ for (int i = 0; i < BM3200_8640; i += 32) {
+ __m256i vec_as[KK / 2];
+ __m256i vec_signs[KK / 8];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+ #pragma unroll
+ for (int as = 0; as < KK / 8; as++) {
+ vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ __m256i vec_sign = vec_signs[k];
+ __m256i vec_a_0 = vec_as[k * 4 + 0];
+ __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);
+ __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);
+ __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);
+ __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);
+ __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);
+ __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);
+ __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);
+ __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);
+ __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);
+ __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);
+ __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);
+ __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);
+ __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);
+ __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);
+ __m256i vec_a_1 = vec_as[k * 4 + 1];
+ __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);
+ __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);
+ __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);
+ __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);
+ __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1);
+ __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);
+ __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);
+ __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);
+ __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);
+ __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);
+ __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);
+ __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);
+ __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);
+ __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);
+ __m256i vec_a_2 = vec_as[k * 4 + 2];
+ __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);
+ __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);
+ __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);
+ __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);
+ __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);
+ __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);
+ __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);
+ __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);
+ __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);
+ __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);
+ __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);
+ __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);
+ __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);
+ __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);
+ __m256i vec_a_3 = vec_as[k * 4 + 3];
+ __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);
+ __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);
+ __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);
+ __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);
+ __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);
+ __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);
+ __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);
+ __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);
+ __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);
+ __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);
+ __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);
+ __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);
+ __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);
+ __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);
+ }
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs));
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs), vec_gc3);
+ }
+ }
+#endif
+}
+
+template
+inline int32_t two_tbl_impl3200_8640(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const int KK = BK2 / 2;
+#pragma unroll
+ for (int i = 0; i < BM3200_8640; i += 32) {
+ __m256i vec_as[KK / 2];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ __m256i vec_a = vec_as[k * 4 + j];
+
+ __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));
+ __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));
+ __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));
+ __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));
+
+ __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);
+ __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);
+ __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);
+
+ __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);
+ __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);
+ __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);
+
+ __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);
+ }
+ }
+
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs));
+
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_8640 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_8640 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_8640 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_8640 * bs), vec_gc3);
+ }
+ }
+#endif
+ return 0;
+}
+
+template
+int32_t three_qgemm_lut_3200_8640(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_8640];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_8640 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 8640 / BBK3200_8640; ++k_outer) {
+ three_tbl_impl_3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_8640 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_8640 / 3 / 2 * BM3200_8640)])), (&(((uint8_t*)sign)[(k_outer * BBK3200_8640 / 3 / 8 * BM3200_8640)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM3200_8640; i++) {
+ ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM3200_8640]);
+ }
+ }
+ return 0;
+}
+
+template
+int32_t two_qgemm_lut_3200_8640(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_8640];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_8640 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 0 / 32; ++k_outer) {
+ two_tbl_impl3200_8640((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM3200_8640)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM3200_8640; i++) {
+ ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM3200_8640]);
+ ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];
+ }
+ }
+ return 0;
+}
+
+#include
+
+#define BM3200_3200 320
+#define BBK3200_3200 96
+template
+inline void three_tbl_impl_3200_3200(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);
+ const __m256i vec_zero = _mm256_set1_epi8(0x00);
+ const __m256i vec_one = _mm256_set1_epi8(0xff);
+ const int KK = BBK3200_3200 / 3;
+#pragma unroll
+ for (int i = 0; i < BM3200_3200; i += 32) {
+ __m256i vec_as[KK / 2];
+ __m256i vec_signs[KK / 8];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+ #pragma unroll
+ for (int as = 0; as < KK / 8; as++) {
+ vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ __m256i vec_sign = vec_signs[k];
+ __m256i vec_a_0 = vec_as[k * 4 + 0];
+ __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);
+ __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);
+ __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);
+ __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);
+ __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);
+ __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);
+ __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);
+ __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);
+ __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);
+ __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);
+ __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);
+ __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);
+ __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);
+ __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);
+ __m256i vec_a_1 = vec_as[k * 4 + 1];
+ __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);
+ __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);
+ __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);
+ __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);
+ __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1);
+ __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);
+ __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);
+ __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);
+ __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);
+ __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);
+ __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);
+ __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);
+ __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);
+ __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);
+ __m256i vec_a_2 = vec_as[k * 4 + 2];
+ __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);
+ __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);
+ __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);
+ __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);
+ __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);
+ __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);
+ __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);
+ __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);
+ __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);
+ __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);
+ __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);
+ __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);
+ __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);
+ __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);
+ __m256i vec_a_3 = vec_as[k * 4 + 3];
+ __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);
+ __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);
+ __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);
+ __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);
+ __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);
+ __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);
+ __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);
+ __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);
+ __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);
+ __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);
+ __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);
+ __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);
+ __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);
+ __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);
+ }
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs));
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs), vec_gc3);
+ }
+ }
+#endif
+}
+
+template
+inline int32_t two_tbl_impl3200_3200(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const int KK = BK2 / 2;
+#pragma unroll
+ for (int i = 0; i < BM3200_3200; i += 32) {
+ __m256i vec_as[KK / 2];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ __m256i vec_a = vec_as[k * 4 + j];
+
+ __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));
+ __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));
+ __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));
+ __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));
+
+ __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);
+ __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);
+ __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);
+
+ __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);
+ __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);
+ __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);
+
+ __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);
+ }
+ }
+
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs));
+
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM3200_3200 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM3200_3200 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM3200_3200 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM3200_3200 * bs), vec_gc3);
+ }
+ }
+#endif
+ return 0;
+}
+
+template
+int32_t three_qgemm_lut_3200_3200(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_3200];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_3200 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 3168 / BBK3200_3200; ++k_outer) {
+ three_tbl_impl_3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK3200_3200 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK3200_3200 / 3 / 2 * BM3200_3200)])), (&(((uint8_t*)sign)[(k_outer * BBK3200_3200 / 3 / 8 * BM3200_3200)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM3200_3200; i++) {
+ ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM3200_3200]);
+ }
+ }
+ return 0;
+}
+
+template
+int32_t two_qgemm_lut_3200_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM3200_3200];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM3200_3200 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 32 / 32; ++k_outer) {
+ two_tbl_impl3200_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM3200_3200)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM3200_3200; i++) {
+ ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM3200_3200]);
+ ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];
+ }
+ }
+ return 0;
+}
+
+#include
+
+#define BM8640_3200 320
+#define BBK8640_3200 96
+template
+inline void three_tbl_impl_8640_3200(int32_t* c, int8_t* lut, uint8_t* a, uint8_t* sign) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const __m256i vec_sign_mask = _mm256_set1_epi16(0x8000);
+ const __m256i vec_zero = _mm256_set1_epi8(0x00);
+ const __m256i vec_one = _mm256_set1_epi8(0xff);
+ const int KK = BBK8640_3200 / 3;
+#pragma unroll
+ for (int i = 0; i < BM8640_3200; i += 32) {
+ __m256i vec_as[KK / 2];
+ __m256i vec_signs[KK / 8];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+ #pragma unroll
+ for (int as = 0; as < KK / 8; as++) {
+ vec_signs[as] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(sign + i * KK / 8 + as * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ __m256i vec_sign = vec_signs[k];
+ __m256i vec_a_0 = vec_as[k * 4 + 0];
+ __m128i vec_k1_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_0 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 0 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0)), 15);
+ __m256i vec_sign_left_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 1)), 15);
+ __m256i vec_v_top_0 = _mm256_and_si256(_mm256_srli_epi16(vec_a_0, 4), vec_mask);
+ __m256i vec_v_top_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_0, vec_k1_0), vec_v_top_0);
+ __m256i vec_v_top_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_0, vec_k2_0), vec_v_top_0);
+ __m256i vec_sign_right_hi_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 2)), 15);
+ __m256i vec_sign_right_lo_0 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 0 + 3)), 15);
+ __m256i vec_v_bot_0 = _mm256_and_si256(vec_a_0, vec_mask);
+ __m256i vec_v_bot_fir_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_0, vec_k3_0), vec_v_bot_0);
+ __m256i vec_v_bot_sec_0 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_0, vec_k4_0), vec_v_bot_0);
+ __m256i vec_v_top_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_lo_0), vec_sign_left_lo_0);
+ __m256i vec_v_top_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_0, vec_v_top_sec_0), vec_sign_left_hi_0), vec_sign_left_hi_0);
+ __m256i vec_v_bot_lo_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_lo_0), vec_sign_right_lo_0);
+ __m256i vec_v_bot_hi_0 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_0, vec_v_bot_sec_0), vec_sign_right_hi_0), vec_sign_right_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_0);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_0);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_0);
+ __m256i vec_a_1 = vec_as[k * 4 + 1];
+ __m128i vec_k1_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 1 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1)), 15);
+ __m256i vec_sign_left_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 1)), 15);
+ __m256i vec_v_top_1 = _mm256_and_si256(_mm256_srli_epi16(vec_a_1, 4), vec_mask);
+ __m256i vec_v_top_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_1, vec_k1_1), vec_v_top_1);
+ __m256i vec_v_top_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_1, vec_k2_1), vec_v_top_1);
+ __m256i vec_sign_right_hi_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 2)), 15);
+ __m256i vec_sign_right_lo_1 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 1 + 3)), 15);
+ __m256i vec_v_bot_1 = _mm256_and_si256(vec_a_1, vec_mask);
+ __m256i vec_v_bot_fir_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_1, vec_k3_1), vec_v_bot_1);
+ __m256i vec_v_bot_sec_1 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_1, vec_k4_1), vec_v_bot_1);
+ __m256i vec_v_top_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_lo_1), vec_sign_left_lo_1);
+ __m256i vec_v_top_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_1, vec_v_top_sec_1), vec_sign_left_hi_1), vec_sign_left_hi_1);
+ __m256i vec_v_bot_lo_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_lo_1), vec_sign_right_lo_1);
+ __m256i vec_v_bot_hi_1 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_1, vec_v_bot_sec_1), vec_sign_right_hi_1), vec_sign_right_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_1);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_1);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_1);
+ __m256i vec_a_2 = vec_as[k * 4 + 2];
+ __m128i vec_k1_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 2 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2)), 15);
+ __m256i vec_sign_left_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 1)), 15);
+ __m256i vec_v_top_2 = _mm256_and_si256(_mm256_srli_epi16(vec_a_2, 4), vec_mask);
+ __m256i vec_v_top_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_2, vec_k1_2), vec_v_top_2);
+ __m256i vec_v_top_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_2, vec_k2_2), vec_v_top_2);
+ __m256i vec_sign_right_hi_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 2)), 15);
+ __m256i vec_sign_right_lo_2 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 2 + 3)), 15);
+ __m256i vec_v_bot_2 = _mm256_and_si256(vec_a_2, vec_mask);
+ __m256i vec_v_bot_fir_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_2, vec_k3_2), vec_v_bot_2);
+ __m256i vec_v_bot_sec_2 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_2, vec_k4_2), vec_v_bot_2);
+ __m256i vec_v_top_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_lo_2), vec_sign_left_lo_2);
+ __m256i vec_v_top_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_2, vec_v_top_sec_2), vec_sign_left_hi_2), vec_sign_left_hi_2);
+ __m256i vec_v_bot_lo_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_lo_2), vec_sign_right_lo_2);
+ __m256i vec_v_bot_hi_2 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_2, vec_v_bot_sec_2), vec_sign_right_hi_2), vec_sign_right_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_2);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_2);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_2);
+ __m256i vec_a_3 = vec_as[k * 4 + 3];
+ __m128i vec_k1_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 0 + K3 / 3 * 32 * bs));
+ __m128i vec_k2_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 16 + K3 / 3 * 32 * bs));
+ __m128i vec_k3_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 32 + K3 / 3 * 32 * bs));
+ __m128i vec_k4_3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + 3 * 64 + 48 + K3 / 3 * 32 * bs));
+ __m256i vec_sign_left_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3)), 15);
+ __m256i vec_sign_left_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 1)), 15);
+ __m256i vec_v_top_3 = _mm256_and_si256(_mm256_srli_epi16(vec_a_3, 4), vec_mask);
+ __m256i vec_v_top_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1_3, vec_k1_3), vec_v_top_3);
+ __m256i vec_v_top_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2_3, vec_k2_3), vec_v_top_3);
+ __m256i vec_sign_right_hi_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 2)), 15);
+ __m256i vec_sign_right_lo_3 = _mm256_srai_epi16(_mm256_slli_epi16(vec_sign, (4 * 3 + 3)), 15);
+ __m256i vec_v_bot_3 = _mm256_and_si256(vec_a_3, vec_mask);
+ __m256i vec_v_bot_fir_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3_3, vec_k3_3), vec_v_bot_3);
+ __m256i vec_v_bot_sec_3 = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4_3, vec_k4_3), vec_v_bot_3);
+ __m256i vec_v_top_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_lo_3), vec_sign_left_lo_3);
+ __m256i vec_v_top_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_top_fir_3, vec_v_top_sec_3), vec_sign_left_hi_3), vec_sign_left_hi_3);
+ __m256i vec_v_bot_lo_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpackhi_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_lo_3), vec_sign_right_lo_3);
+ __m256i vec_v_bot_hi_3 = _mm256_xor_si256(_mm256_add_epi16(_mm256_unpacklo_epi8(vec_v_bot_fir_3, vec_v_bot_sec_3), vec_sign_right_hi_3), vec_sign_right_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi_3);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo_3);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo_3);
+ }
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs));
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs), vec_gc3);
+ }
+ }
+#endif
+}
+
+template
+inline int32_t two_tbl_impl8640_3200(int32_t* c, int8_t* lut, uint8_t* a) {
+#ifdef __AVX2__
+ const __m256i vec_mask = _mm256_set1_epi8(0x0f);
+ const int KK = BK2 / 2;
+#pragma unroll
+ for (int i = 0; i < BM8640_3200; i += 32) {
+ __m256i vec_as[KK / 2];
+ #pragma unroll
+ for (int ai = 0; ai < KK / 2; ai++) {
+ vec_as[ai] = _mm256_loadu_si256(reinterpret_cast<__m256i*>(a + i * KK / 2 + ai * 32));
+ }
+#pragma unroll
+ for (int bs = 0; bs < batch_size; bs++) {
+ __m256i vec_c0 = _mm256_setzero_si256();
+ __m256i vec_c1 = _mm256_setzero_si256();
+#pragma unroll
+ for (int k = 0; k < KK / 8; k++) {
+ #pragma unroll
+ for (int j = 0; j < 4; j++) {
+ __m256i vec_a = vec_as[k * 4 + j];
+
+ __m128i vec_k1 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 0 + K2 / 2 * 32 * bs));
+ __m128i vec_k2 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 16 + K2 / 2 * 32 * bs));
+ __m128i vec_k3 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 32 + K2 / 2 * 32 * bs));
+ __m128i vec_k4 = _mm_loadu_si128(reinterpret_cast<__m128i*>(lut + k * 32 * 8 + j * 64 + 48 + K2 / 2 * 32 * bs));
+
+ __m256i vec_v_top = _mm256_and_si256(_mm256_srli_epi16(vec_a, 4), vec_mask);
+ __m256i vec_v_top_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k1, vec_k1), vec_v_top);
+ __m256i vec_v_top_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k2, vec_k2), vec_v_top);
+
+ __m256i vec_v_bot = _mm256_and_si256(vec_a, vec_mask);
+ __m256i vec_v_bot_fir = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k3, vec_k3), vec_v_bot);
+ __m256i vec_v_bot_sec = _mm256_shuffle_epi8(_mm256_set_m128i(vec_k4, vec_k4), vec_v_bot);
+
+ __m256i vec_v_top_lo = _mm256_unpackhi_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_top_hi = _mm256_unpacklo_epi8(vec_v_top_fir, vec_v_top_sec);
+ __m256i vec_v_bot_lo = _mm256_unpackhi_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ __m256i vec_v_bot_hi = _mm256_unpacklo_epi8(vec_v_bot_fir, vec_v_bot_sec);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_top_hi);
+ vec_c0 = _mm256_add_epi16(vec_c0, vec_v_bot_hi);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_top_lo);
+ vec_c1 = _mm256_add_epi16(vec_c1, vec_v_bot_lo);
+ }
+ }
+
+ __m256i vec_gc0 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs));
+ __m256i vec_gc1 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs));
+ __m256i vec_gc2 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs));
+ __m256i vec_gc3 = _mm256_loadu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs));
+
+ vec_gc0 = _mm256_add_epi32(vec_gc0, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c0)));
+ vec_gc1 = _mm256_add_epi32(vec_gc1, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c0, 1)));
+ vec_gc2 = _mm256_add_epi32(vec_gc2, _mm256_cvtepi16_epi32(_mm256_castsi256_si128(vec_c1)));
+ vec_gc3 = _mm256_add_epi32(vec_gc3, _mm256_cvtepi16_epi32(_mm256_extracti128_si256(vec_c1, 1)));
+
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + BM8640_3200 * bs), vec_gc0);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 8 + BM8640_3200 * bs), vec_gc1);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 16 + BM8640_3200 * bs), vec_gc2);
+ _mm256_storeu_si256(reinterpret_cast<__m256i*>(c + i + 24 + BM8640_3200 * bs), vec_gc3);
+ }
+ }
+#endif
+ return 0;
+}
+
+template
+int32_t three_qgemm_lut_8640_3200(void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM8640_3200];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM8640_3200 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 3168 / BBK8640_3200; ++k_outer) {
+ three_tbl_impl_8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BBK8640_3200 / 3 * 32)])), (&(((uint8_t*)A)[(k_outer * BBK8640_3200 / 3 / 2 * BM8640_3200)])), (&(((uint8_t*)sign)[(k_outer * BBK8640_3200 / 3 / 8 * BM8640_3200)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM8640_3200; i++) {
+ ((int32_t*)C)[i] = (int32_t)(((int32_t*)CBits)[i + bs * BM8640_3200]);
+ }
+ }
+ return 0;
+}
+
+template
+int32_t two_qgemm_lut_8640_3200(void* A, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ alignas(32) uint32_t CBits[BATCH_SIZE * BM8640_3200];
+ memset(&(CBits[0]), 0, BATCH_SIZE * BM8640_3200 * sizeof(int32_t));
+#pragma unroll
+ for (int32_t k_outer = 0; k_outer < 32 / 32; ++k_outer) {
+ two_tbl_impl8640_3200((&(((int32_t*)CBits)[0])), (&(((int8_t*)LUT)[(k_outer * BK2 / 2 * 32)])), (&(((uint8_t*)A)[(k_outer * BK2 / 2 / 2 * BM8640_3200)])));
+ }
+#pragma unroll
+ for (int bs = 0; bs < BATCH_SIZE; bs++) {
+#pragma unroll
+ for (int i = 0; i < BM8640_3200; i++) {
+ ((int32_t*)C)[i] += (int32_t)(((int32_t*)CBits)[i + bs * BM8640_3200]);
+ ((float*)C)[i] = (float)(((int32_t*)C)[i]) / ((float*)LUT_Scales)[bs] * ((float*)Scales)[0];
+ }
+ }
+ return 0;
+}
+
+void ggml_preprocessor(int bs, int m, int three_k, int two_k, void* B, void* LUT_Scales, void* Three_QLUT, void* Two_QLUT) {
+ partial_max_reset(bs, (&(((float*)LUT_Scales)[0])));
+ if (m == 3200 && two_k == 0 && three_k == 8640) {
+ for (int32_t b = 0; b < bs; b++) {
+ per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));
+ three_lut_ctor<8640>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));
+ two_lut_ctor<0>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 8640])), (&(((float*)LUT_Scales)[b])));
+ }
+ }
+ else if (m == 3200 && two_k == 32 && three_k == 3168) {
+ for (int32_t b = 0; b < bs; b++) {
+ per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));
+ three_lut_ctor<3168>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));
+ two_lut_ctor<32>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 3168])), (&(((float*)LUT_Scales)[b])));
+ }
+ }
+ else if (m == 8640 && two_k == 32 && three_k == 3168) {
+ for (int32_t b = 0; b < bs; b++) {
+ per_tensor_quant(two_k + three_k, (&(((float*)LUT_Scales)[b])), (&(((float*)B)[b * (two_k + three_k)])));
+ three_lut_ctor<3168>((&(((int8_t*)Three_QLUT)[b * three_k / 3 * 32])), (&(((float*)B)[b * (three_k + two_k)])), (&(((float*)LUT_Scales)[b])));
+ two_lut_ctor<32>((&(((int8_t*)Two_QLUT)[b * two_k / 2 * 32])), (&(((float*)B)[b * (three_k + two_k) + 3168])), (&(((float*)LUT_Scales)[b])));
+ }
+ }
+}
+void ggml_qgemm_lut(int bs, int m, int k, int BK, void* A, void* sign, void* LUT, void* Scales, void* LUT_Scales, void* C) {
+ if (m == 3200 && k == 8640) {
+ if (BK == 0) {
+ if (bs == 1) {
+ two_qgemm_lut_3200_8640<1>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 8) {
+ two_qgemm_lut_3200_8640<8>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 32) {
+ two_qgemm_lut_3200_8640<32>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 128) {
+ two_qgemm_lut_3200_8640<128>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 256) {
+ two_qgemm_lut_3200_8640<256>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 512) {
+ two_qgemm_lut_3200_8640<512>(A, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ else if (BK == 8640) {
+ if (bs == 1) {
+ three_qgemm_lut_3200_8640<1>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 8) {
+ three_qgemm_lut_3200_8640<8>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 32) {
+ three_qgemm_lut_3200_8640<32>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 128) {
+ three_qgemm_lut_3200_8640<128>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 256) {
+ three_qgemm_lut_3200_8640<256>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 512) {
+ three_qgemm_lut_3200_8640<512>(A, sign, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ }
+ else if (m == 3200 && k == 3200) {
+ if (BK == 32) {
+ if (bs == 1) {
+ two_qgemm_lut_3200_3200<1>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 8) {
+ two_qgemm_lut_3200_3200<8>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 32) {
+ two_qgemm_lut_3200_3200<32>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 128) {
+ two_qgemm_lut_3200_3200<128>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 256) {
+ two_qgemm_lut_3200_3200<256>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 512) {
+ two_qgemm_lut_3200_3200<512>(A, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ else if (BK == 3168) {
+ if (bs == 1) {
+ three_qgemm_lut_3200_3200<1>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 8) {
+ three_qgemm_lut_3200_3200<8>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 32) {
+ three_qgemm_lut_3200_3200<32>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 128) {
+ three_qgemm_lut_3200_3200<128>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 256) {
+ three_qgemm_lut_3200_3200<256>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 512) {
+ three_qgemm_lut_3200_3200<512>(A, sign, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ }
+ else if (m == 8640 && k == 3200) {
+ if (BK == 32) {
+ if (bs == 1) {
+ two_qgemm_lut_8640_3200<1>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 8) {
+ two_qgemm_lut_8640_3200<8>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 32) {
+ two_qgemm_lut_8640_3200<32>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 128) {
+ two_qgemm_lut_8640_3200<128>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 256) {
+ two_qgemm_lut_8640_3200<256>(A, LUT, Scales, LUT_Scales, C);
+ } else if (bs == 512) {
+ two_qgemm_lut_8640_3200<512>(A, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ else if (BK == 3168) {
+ if (bs == 1) {
+ three_qgemm_lut_8640_3200<1>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 8) {
+ three_qgemm_lut_8640_3200<8>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 32) {
+ three_qgemm_lut_8640_3200<32>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 128) {
+ three_qgemm_lut_8640_3200<128>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 256) {
+ three_qgemm_lut_8640_3200<256>(A, sign, LUT, Scales, LUT_Scales, C);
+ }else if (bs == 512) {
+ three_qgemm_lut_8640_3200<512>(A, sign, LUT, Scales, LUT_Scales, C);
+ }
+ }
+ }
+}
+
+void ggml_bitnet_transform_tensor(struct ggml_tensor * tensor) {
+ if (!(is_type_supported(tensor->type) && tensor->backend == GGML_BACKEND_TYPE_CPU && tensor->extra == nullptr)) {
+ return;
+ }
+
+ int k = tensor->ne[0];
+ int m = tensor->ne[1];
+ const int lut_scales_size = 1;
+ int bk = 0;
+ int bm = 0;
+
+ if (m == 3200 && k == 8640) {
+ bm = BM3200_8640;
+ bk = BBK3200_8640;
+ }
+else if (m == 3200 && k == 3200) {
+ bm = BM3200_3200;
+ bk = BBK3200_3200;
+ }
+else if (m == 8640 && k == 3200) {
+ bm = BM8640_3200;
+ bk = BBK8640_3200;
+ }
+
+ const int n_tile_num = m / bm;
+ const int BK = bk;
+ uint8_t * qweights;
+ bitnet_float_type * scales;
+
+ scales = (bitnet_float_type *) aligned_malloc(sizeof(bitnet_float_type));
+ qweights = (uint8_t *) tensor->data;
+ int nbytes = (k - 256) * m / 3 * 5 / 8 + 256 * m / 2 * 4 / 8;
+ if (nbytes % 32 != 0) nbytes = 32 - nbytes % 32 + nbytes;
+ float * i2_scales = (float * )(qweights + nbytes);
+ scales[0] = (bitnet_float_type) i2_scales[0];
+
+ tensor->extra = bitnet_tensor_extras + bitnet_tensor_extras_index;
+ bitnet_tensor_extras[bitnet_tensor_extras_index++] = {
+ /* .lut_scales_size = */ lut_scales_size,
+ /* .BK = */ BK,
+ /* .n_tile_num = */ n_tile_num,
+ /* .qweights = */ qweights,
+ /* .scales = */ scales
+ };
+}
+#endif
\ No newline at end of file
diff --git a/include/kernel_config.ini b/include/kernel_config.ini
new file mode 100644
index 000000000..46468bf13
--- /dev/null
+++ b/include/kernel_config.ini
@@ -0,0 +1,21 @@
+[Kernels_0]
+m = 3200
+k = 8640
+bm = 160
+bk = 96
+bmm = 32
+
+[Kernels_1]
+m = 3200
+k = 3200
+bm = 320
+bk = 96
+bmm = 32
+
+[Kernels_2]
+m = 8640
+k = 3200
+bm = 320
+bk = 96
+bmm = 32
+
diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp
index 4ba9d6509..019e25c7c 100644
--- a/src/ggml-bitnet-mad.cpp
+++ b/src/ggml-bitnet-mad.cpp
@@ -808,7 +808,7 @@ void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size
accu[iy] = _mm256_setzero_si256();
}
- int8_t * y_col = y + col * by;
+ const int8_t * y_col = y + col * by;
for (int i = 0; i < group32_num; i++) {
const uint8_t *px = x + i * 1024;
@@ -1053,4 +1053,4 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b
{
ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc);
}
-}
\ No newline at end of file
+}
diff --git a/utils/convert-helper-bitnet.py b/utils/convert-helper-bitnet.py
index 9ed8db013..f6e95862a 100644
--- a/utils/convert-helper-bitnet.py
+++ b/utils/convert-helper-bitnet.py
@@ -36,6 +36,8 @@ def main():
convert_script = utils_dir / "convert-ms-to-gguf-bitnet.py"
llama_quantize_binary = project_root_dir / "build" / "bin" / "llama-quantize"
+ if sys.platform == "win32": # Add Windows Path support
+ llama_quantize_binary = llama_quantize_binary.with_suffix(".exe")
input_file = model_dir / "model.safetensors"
input_backup_file = model_dir / "model.safetensors.backup"
@@ -131,4 +133,4 @@ def main():
print(f"Warning: Could not remove backup {input_backup_file}: {e}")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()