mirror of
https://github.com/fumiama/base16384-sycl.git
synced 2026-06-10 21:24:47 +08:00
optimize(test): move test kernels into test class
This commit is contained in:
@@ -73,6 +73,7 @@ message(STATUS "Collected libs: ${B14LIBS}")
|
|||||||
enable_testing()
|
enable_testing()
|
||||||
add_subdirectory(tests)
|
add_subdirectory(tests)
|
||||||
message(STATUS "Collected tests: ${B14TESTS}")
|
message(STATUS "Collected tests: ${B14TESTS}")
|
||||||
|
|
||||||
foreach(TARGET_NAME ${B14TESTS})
|
foreach(TARGET_NAME ${B14TESTS})
|
||||||
target_link_libraries(${TARGET_NAME} ${B14LIBS})
|
target_link_libraries(${TARGET_NAME} ${B14LIBS})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|||||||
15
include/test.hpp
Normal file
15
include/test.hpp
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
#ifndef _TEST_KERNELS_H_
|
||||||
|
#define _TEST_KERNELS_H_
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
namespace base16384 {
|
||||||
|
// test groups the demo calculations that are called both from the host loop
// and from inside the SYCL parallel_for lambdas.
class test {
|
||||||
|
public:
|
||||||
|
// kernels_basic is a demo calculation that maps one input byte to one output
|
||||||
|
// byte using mod, bit, plus and mul calculations.
// NOTE(review): this header includes only <stdint.h>; SYCL_EXTERNAL is
// presumably supplied by <sycl/sycl.hpp>, so the including file has to pull
// that in before this header - confirm.
|
||||||
|
SYCL_EXTERNAL static uint8_t kernels_basic(uint8_t in);
|
||||||
|
};
|
||||||
|
} // namespace base16384
|
||||||
|
|
||||||
|
#endif
|
||||||
@@ -1,6 +1,9 @@
|
|||||||
#ifndef _XEINFO_HPP_
|
#ifndef _XEINFO_HPP_
|
||||||
#define _XEINFO_HPP_
|
#define _XEINFO_HPP_
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <iomanip>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
@@ -59,7 +62,7 @@ class xeinfo {
|
|||||||
const int num_subslices_per_slice;
|
const int num_subslices_per_slice;
|
||||||
const int num_eus_per_subslice;
|
const int num_eus_per_subslice;
|
||||||
const int num_threads_per_eu;
|
const int num_threads_per_eu;
|
||||||
const int global_mem_size;
|
const uint64_t global_mem_size;
|
||||||
const int local_mem_size;
|
const int local_mem_size;
|
||||||
const int max_work_group_size;
|
const int max_work_group_size;
|
||||||
const std::vector<unsigned long long> sub_group_sizes;
|
const std::vector<unsigned long long> sub_group_sizes;
|
||||||
@@ -84,8 +87,9 @@ class xeinfo {
|
|||||||
builder << " 每个 XeCore 的硬件线程数: " << num_thread_per_xecore << "\n";
|
builder << " 每个 XeCore 的硬件线程数: " << num_thread_per_xecore << "\n";
|
||||||
builder << " 每个向量引擎的硬件线程数: " << num_threads_per_eu << "\n";
|
builder << " 每个向量引擎的硬件线程数: " << num_threads_per_eu << "\n";
|
||||||
builder << " 硬件线程总数: " << total_hardware_threads << "\n";
|
builder << " 硬件线程总数: " << total_hardware_threads << "\n";
|
||||||
builder << " GPU 内存大小: " << global_mem_size << " 字节\n";
|
builder << " GPU 内存大小: " << global_mem_size << " B (" << std::fixed << std::setprecision(2)
|
||||||
builder << " 每个工作组的共享本地内存: " << local_mem_size << " 字节\n";
|
<< (double)global_mem_size / 1024 / 1024 / 1024 << " GB)\n";
|
||||||
|
builder << " 每个工作组的共享本地内存: " << local_mem_size << " B\n";
|
||||||
builder << " 最大工作组大小: " << max_work_group_size << "\n";
|
builder << " 最大工作组大小: " << max_work_group_size << "\n";
|
||||||
builder << " 支持的子组大小:";
|
builder << " 支持的子组大小:";
|
||||||
for (size_t i = 0; i < sub_group_sizes.size(); i++) builder << " " << sub_group_sizes[i];
|
for (size_t i = 0; i < sub_group_sizes.size(); i++) builder << " " << sub_group_sizes[i];
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ set(LOCAL_B14LIBS "")
|
|||||||
foreach(CPP_FILE ${CPP_FILES})
|
foreach(CPP_FILE ${CPP_FILES})
|
||||||
# name without .cpp
|
# name without .cpp
|
||||||
get_filename_component(TARGET_NAME ${CPP_FILE} NAME_WE)
|
get_filename_component(TARGET_NAME ${CPP_FILE} NAME_WE)
|
||||||
message(STATUS "Add lib: ${TARGET_NAME}")
|
message(STATUS "Add CPP lib: ${TARGET_NAME}")
|
||||||
add_library(${TARGET_NAME} STATIC ${CPP_FILE})
|
add_library(${TARGET_NAME} STATIC ${CPP_FILE})
|
||||||
|
|
||||||
set_target_properties(${TARGET_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}")
|
set_target_properties(${TARGET_NAME} PROPERTIES COMPILE_FLAGS "${COMPILE_FLAGS}")
|
||||||
set_target_properties(${TARGET_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
|
set_target_properties(${TARGET_NAME} PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
|
||||||
|
|
||||||
list(APPEND LOCAL_B14LIBS ${TARGET_NAME})
|
list(APPEND LOCAL_B14LIBS ${TARGET_NAME})
|
||||||
endforeach()
|
endforeach()
|
||||||
|
|
||||||
|
|||||||
14
libs/test_kernels.cpp
Normal file
14
libs/test_kernels.cpp
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <sycl/sycl.hpp>
|
||||||
|
|
||||||
|
#include "test.hpp"
|
||||||
|
|
||||||
|
// kernels_basic is the demo byte-to-byte calculation (mod, bit, plus and mul
// steps) shared by the host loop and the SYCL device lambdas.
//
// in: the input byte.
// Returns the transformed byte.
//
// Every step is evaluated in a wider type and narrowed explicitly afterwards.
// In particular the square is reduced modulo 251 BEFORE it is narrowed to
// 8 bits: `in *= in` would first wrap the product modulo 256, which leaves the
// prime modulus with almost nothing to do and yields different values than
// `(in * in) % 251` (e.g. 16 -> 5 instead of 16 -> 0 for this step).
SYCL_EXTERNAL uint8_t base16384::test::kernels_basic(uint8_t in) {
    const uint32_t squared = static_cast<uint32_t>(in) * in;
    uint8_t value = static_cast<uint8_t>(squared % 251u);  // always < 251
    value = static_cast<uint8_t>(value ^ (value >> 2));
    value = static_cast<uint8_t>(value + 17);  // wraps modulo 256
    value = static_cast<uint8_t>(value * 3);   // wraps modulo 256
    return static_cast<uint8_t>(value ^ (value << 1));
}
|
||||||
@@ -14,6 +14,7 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "errors.hpp"
|
#include "errors.hpp"
|
||||||
|
#include "test.hpp"
|
||||||
#include "xeinfo.hpp"
|
#include "xeinfo.hpp"
|
||||||
|
|
||||||
constexpr int iter_count = 65536;
|
constexpr int iter_count = 65536;
|
||||||
@@ -67,13 +68,7 @@ int main() {
|
|||||||
auto start_time = std::chrono::high_resolution_clock::now();
|
auto start_time = std::chrono::high_resolution_clock::now();
|
||||||
for (int j = 0; j < iter_count; j++) {
|
for (int j = 0; j < iter_count; j++) {
|
||||||
for (auto& byte : cpu_data) {
|
for (auto& byte : cpu_data) {
|
||||||
// 复杂计算:多步数学运算组合
|
byte = base16384::test::kernels_basic(byte);
|
||||||
uint8_t temp = byte;
|
|
||||||
temp = (temp * temp) % 251; // 使用质数避免快速收敛
|
|
||||||
temp = temp ^ (temp >> 2); // 位运算
|
|
||||||
temp = (temp + 17) % 256; // 加法和模运算
|
|
||||||
temp = temp * 3 % 256; // 乘法
|
|
||||||
byte = temp ^ (temp << 1); // 最终位运算
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto end_time = std::chrono::high_resolution_clock::now();
|
auto end_time = std::chrono::high_resolution_clock::now();
|
||||||
@@ -91,15 +86,8 @@ int main() {
|
|||||||
start_time = std::chrono::high_resolution_clock::now();
|
start_time = std::chrono::high_resolution_clock::now();
|
||||||
auto errn = base16384::errors::try_failed([&]() {
|
auto errn = base16384::errors::try_failed([&]() {
|
||||||
for (int j = 0; j < iter_count; j++) {
|
for (int j = 0; j < iter_count; j++) {
|
||||||
q.parallel_for(sycl::range<1>(N), [=](sycl::id<1> i) {
|
q.parallel_for(sycl::range<1>(N),
|
||||||
// 复杂计算:多步数学运算组合
|
[=](sycl::id<1> i) { data[i] = base16384::test::kernels_basic(data[i]); });
|
||||||
uint8_t temp = data[i];
|
|
||||||
temp = (temp * temp) % 251; // 使用质数避免快速收敛
|
|
||||||
temp = temp ^ (temp >> 2); // 位运算
|
|
||||||
temp = (temp + 17) % 256; // 加法和模运算
|
|
||||||
temp = temp * 3 % 256; // 乘法
|
|
||||||
data[i] = temp ^ (temp << 1); // 最终位运算
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
q.wait();
|
q.wait();
|
||||||
});
|
});
|
||||||
@@ -120,13 +108,7 @@ int main() {
|
|||||||
q.parallel_for(sycl::nd_range<1>(N, work_group_size),
|
q.parallel_for(sycl::nd_range<1>(N, work_group_size),
|
||||||
[=](sycl::nd_item<1> item) { // sub-group size
|
[=](sycl::nd_item<1> item) { // sub-group size
|
||||||
const auto i = item.get_global_id(0);
|
const auto i = item.get_global_id(0);
|
||||||
// 复杂计算:多步数学运算组合
|
data[i] = base16384::test::kernels_basic(data[i]);
|
||||||
uint8_t temp = data[i];
|
|
||||||
temp = (temp * temp) % 251; // 使用质数避免快速收敛
|
|
||||||
temp = temp ^ (temp >> 2); // 位运算
|
|
||||||
temp = (temp + 17) % 256; // 加法和模运算
|
|
||||||
temp = temp * 3 % 256; // 乘法
|
|
||||||
data[i] = temp ^ (temp << 1); // 最终位运算
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
q.wait();
|
q.wait();
|
||||||
|
|||||||
Reference in New Issue
Block a user