1
0
mirror of https://github.com/fumiama/base16384-sycl.git synced 2026-06-05 00:32:49 +08:00

optimize(test): move test kernels into test class

This commit is contained in:
源文雨
2025-10-09 16:08:09 +08:00
parent cbe9cda397
commit 24ea4ca7bd
6 changed files with 44 additions and 28 deletions

View File

@@ -73,6 +73,7 @@ message(STATUS "Collected libs: ${B14LIBS}")
# Turn on CTest support and process the tests/ subdirectory.
enable_testing()
add_subdirectory(tests)
# Print the list of test targets held in B14TESTS
# (presumably populated by tests/CMakeLists.txt — confirm).
message(STATUS "Collected tests: ${B14TESTS}")
# Link every collected test target against all libraries listed in B14LIBS.
foreach(TARGET_NAME ${B14TESTS})
target_link_libraries(${TARGET_NAME} ${B14LIBS})
endforeach()

15
include/test.hpp Normal file
View File

@@ -0,0 +1,15 @@
#ifndef _TEST_KERNELS_H_
#define _TEST_KERNELS_H_
#include <stdint.h>
namespace base16384 {
// test groups the demo computations that the test programs run both in a
// plain host loop and inside SYCL kernels.
class test {
public:
// kernels_basic is a demo calculation that combines mul, mod, bit (shift/xor)
// and plus operations on a single byte and returns the resulting byte.
// It is declared SYCL_EXTERNAL and defined out of line in
// libs/test_kernels.cpp.
// NOTE(review): this header uses SYCL_EXTERNAL without including a SYCL
// header itself; presumably every includer pulls in <sycl/sycl.hpp> first —
// confirm.
SYCL_EXTERNAL static uint8_t kernels_basic(uint8_t in);
};
} // namespace base16384
#endif

View File

@@ -1,6 +1,9 @@
#ifndef _XEINFO_HPP_
#define _XEINFO_HPP_
#include <stdint.h>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
@@ -59,7 +62,7 @@ class xeinfo {
const int num_subslices_per_slice;
const int num_eus_per_subslice;
const int num_threads_per_eu;
const int global_mem_size;
const uint64_t global_mem_size;
const int local_mem_size;
const int max_work_group_size;
const std::vector<unsigned long long> sub_group_sizes;
@@ -84,8 +87,9 @@ class xeinfo {
builder << " 每个 XeCore 的硬件线程数: " << num_thread_per_xecore << "\n";
builder << " 每个向量引擎的硬件线程数: " << num_threads_per_eu << "\n";
builder << " 硬件线程总数: " << total_hardware_threads << "\n";
builder << " GPU 内存大小: " << global_mem_size << " 字节\n";
builder << " 每个工作组的共享本地内存: " << local_mem_size << " 字节\n";
builder << " GPU 内存大小: " << global_mem_size << " B (" << std::fixed << std::setprecision(2)
<< (double)global_mem_size / 1024 / 1024 / 1024 << " GB)\n";
builder << " 每个工作组的共享本地内存: " << local_mem_size << " B\n";
builder << " 最大工作组大小: " << max_work_group_size << "\n";
builder << " 支持的子组大小:";
for (size_t i = 0; i < sub_group_sizes.size(); i++) builder << " " << sub_group_sizes[i];

View File

@@ -5,12 +5,12 @@ set(LOCAL_B14LIBS "")
# Create one static library target per source file listed in CPP_FILES.
foreach(CPP_FILE ${CPP_FILES})
# Target name is the file name without directory or extension (.cpp).
get_filename_component(TARGET_NAME ${CPP_FILE} NAME_WE)
# NOTE(review): two status messages are emitted per target here; this looks
# like the before/after pair of a diff — confirm only one is intended.
message(STATUS "Add lib: ${TARGET_NAME}")
message(STATUS "Add CPP lib: ${TARGET_NAME}")
add_library(${TARGET_NAME} STATIC ${CPP_FILE})
# Apply the COMPILE_FLAGS / LINK_FLAGS variables to the new target.
set_target_properties(${TARGET_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}")
set_target_properties(${TARGET_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
# Record the target name in LOCAL_B14LIBS.
list(APPEND LOCAL_B14LIBS ${TARGET_NAME})
endforeach()

14
libs/test_kernels.cpp Normal file
View File

@@ -0,0 +1,14 @@
#include <stdint.h>
#include <sycl/sycl.hpp>
#include "test.hpp"
// Demo per-byte calculation used by both the host loop and the SYCL kernels
// of the test program. Applies, in order: square modulo 251, xor with a
// right shift by 2, add 17, multiply by 3, and a final xor with a left
// shift by 1.
//
// @param in  input byte
// @return    transformed byte (the final int result is narrowed to 8 bits)
SYCL_EXTERNAL uint8_t base16384::test::kernels_basic(uint8_t in) {
    // Square in int (integral promotion) and reduce modulo 251 *before*
    // narrowing back to uint8_t. `in *= in` would wrap modulo 256 first,
    // which changes the result (e.g. 16 would yield 0 instead of 5) and
    // differs from the inline `(temp * temp) % 251` this function replaced.
    // 255 * 255 = 65025 fits comfortably in int, so there is no overflow.
    in = static_cast<uint8_t>((in * in) % 251);
    in ^= in >> 2;  // bit operation
    in += 17;       // addition, wraps modulo 256
    in *= 3;        // multiplication, wraps modulo 256
    return static_cast<uint8_t>(in ^ (in << 1));
}

View File

@@ -14,6 +14,7 @@
#include <vector>
#include "errors.hpp"
#include "test.hpp"
#include "xeinfo.hpp"
constexpr int iter_count = 65536;
@@ -67,13 +68,7 @@ int main() {
auto start_time = std::chrono::high_resolution_clock::now();
for (int j = 0; j < iter_count; j++) {
for (auto& byte : cpu_data) {
// 复杂计算:多步数学运算组合
uint8_t temp = byte;
temp = (temp * temp) % 251; // 使用质数避免快速收敛
temp = temp ^ (temp >> 2); // 位运算
temp = (temp + 17) % 256; // 加法和模运算
temp = temp * 3 % 256; // 乘法
byte = temp ^ (temp << 1); // 最终位运算
byte = base16384::test::kernels_basic(byte);
}
}
auto end_time = std::chrono::high_resolution_clock::now();
@@ -91,15 +86,8 @@ int main() {
start_time = std::chrono::high_resolution_clock::now();
auto errn = base16384::errors::try_failed([&]() {
for (int j = 0; j < iter_count; j++) {
q.parallel_for(sycl::range<1>(N), [=](sycl::id<1> i) {
// 复杂计算:多步数学运算组合
uint8_t temp = data[i];
temp = (temp * temp) % 251; // 使用质数避免快速收敛
temp = temp ^ (temp >> 2); // 位运算
temp = (temp + 17) % 256; // 加法和模运算
temp = temp * 3 % 256; // 乘法
data[i] = temp ^ (temp << 1); // 最终位运算
});
q.parallel_for(sycl::range<1>(N),
[=](sycl::id<1> i) { data[i] = base16384::test::kernels_basic(data[i]); });
}
q.wait();
});
@@ -120,13 +108,7 @@ int main() {
q.parallel_for(sycl::nd_range<1>(N, work_group_size),
[=](sycl::nd_item<1> item) { // sub-group size
const auto i = item.get_global_id(0);
// 复杂计算:多步数学运算组合
uint8_t temp = data[i];
temp = (temp * temp) % 251; // 使用质数避免快速收敛
temp = temp ^ (temp >> 2); // 位运算
temp = (temp + 17) % 256; // 加法和模运算
temp = temp * 3 % 256; // 乘法
data[i] = temp ^ (temp << 1); // 最终位运算
data[i] = base16384::test::kernels_basic(data[i]);
});
}
q.wait();