mirror of
https://github.com/fumiama/gozel.git
synced 2026-06-05 00:10:24 +08:00
feat(example): impl. vadd
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -28,8 +28,8 @@ go.work.sum
|
||||
.env
|
||||
|
||||
# Editor/IDE
|
||||
# .idea/
|
||||
# .vscode/
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# Local spec
|
||||
/spec
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/fumiama/gozel/ze"
|
||||
)
|
||||
|
||||
func main() {
|
||||
hs, err := ze.InitGPUDrivers()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fmt.Println(hs)
|
||||
}
|
||||
1
cmd/examples/vadd/.gitignore
vendored
Normal file
1
cmd/examples/vadd/.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
/device*
|
||||
9
cmd/examples/vadd/main.cpp
Normal file
9
cmd/examples/vadd/main.cpp
Normal file
@@ -0,0 +1,9 @@
|
||||
#include <sycl/sycl.hpp>
|
||||
|
||||
// vector_add is the device function behind the vadd example: each work-item
// adds one element of b into the matching element of a, in place.
//
// extern "C" keeps the symbol name unmangled, so the host can look it up as
// "vector_add". No bounds check is performed: the launch must not cover more
// work-items than a and b have elements.
extern "C" SYCL_EXTERNAL
void vector_add(double* a, double* b) {
    const auto item = sycl::ext::oneapi::this_work_item::get_nd_item<1>();
    // get_global_id yields a size_t; keeping the full width avoids the
    // implicit narrowing to int that the index previously went through.
    const size_t idx = item.get_global_id(0);

    a[idx] += b[idx];
}
|
||||
186
cmd/examples/vadd/main.go
Normal file
186
cmd/examples/vadd/main.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/gozel"
|
||||
"github.com/fumiama/gozel/ze"
|
||||
)
|
||||
|
||||
//go:generate clang++ -fsycl -fsycl-device-only -fno-sycl-use-footer -faddrsig -Xclang -emit-llvm-bc main.cpp -o device_func.bc
|
||||
//go:generate sycl-post-link -symbols -split=auto -o device_func.table device_func.bc
|
||||
//go:generate llvm-spirv -o device_func.spv device_func_0.bc
|
||||
//go:generate clang++ -target spir64-unknown-unknown -S -emit-llvm -x ir device_func_0.bc -o device_func.ll
|
||||
//go:generate go run ../../func2kernel device_func.ll device_kern.ll
|
||||
//go:generate clang++ -target spir64-unknown-unknown -c -emit-llvm -x ir device_kern.ll -o device_kern.bc
|
||||
//go:generate llvm-spirv -o main.spv device_kern.bc
|
||||
//go:generate clang++ -target spir64-unknown-unknown -S -emit-llvm -x ir device_kern.bc -o main.ll
|
||||
|
||||
//go:embed main.spv
|
||||
var kernelspv []byte
|
||||
|
||||
const (
|
||||
X, Y, Z = 1024, 1, 1
|
||||
N = X * Y * Z
|
||||
bufsz = N * unsafe.Sizeof(float64(0))
|
||||
)
|
||||
|
||||
func main() {
|
||||
floatbuf := make([]float64, 2*N)
|
||||
for i := range floatbuf {
|
||||
floatbuf[i] = rand.Float64()
|
||||
}
|
||||
|
||||
gpus, err := ze.InitGPUDrivers()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(gpus) == 0 {
|
||||
panic("no gpu available")
|
||||
}
|
||||
gpu := gpus[0]
|
||||
|
||||
ctx, err := gpu.ContextCreate()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
devs, err := gpu.DeviceGet()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(devs) == 0 {
|
||||
panic("no device available")
|
||||
}
|
||||
dev := devs[0]
|
||||
|
||||
q, err := ctx.CommandQueueCreate(dev)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer q.Destroy()
|
||||
|
||||
hbuf_v1, err := ctx.MemAllocHost(bufsz, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer ctx.MemFree(hbuf_v1)
|
||||
|
||||
hbuf_v2, err := ctx.MemAllocHost(bufsz, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer ctx.MemFree(hbuf_v2)
|
||||
|
||||
dbuf_v1, err := ctx.MemAllocDevice(dev, bufsz, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer ctx.MemFree(dbuf_v1)
|
||||
|
||||
dbuf_v2, err := ctx.MemAllocDevice(dev, bufsz, 1)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer ctx.MemFree(dbuf_v2)
|
||||
|
||||
zev1, zev2 := unsafe.Slice((*float64)(hbuf_v1), N), unsafe.Slice((*float64)(hbuf_v2), N)
|
||||
copy(zev1, floatbuf[:N])
|
||||
copy(zev2, floatbuf[N:])
|
||||
|
||||
mod, err := ctx.ModuleCreate(dev, kernelspv)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer mod.Destroy()
|
||||
|
||||
krn, err := mod.KernelCreate("vector_add")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer krn.Destroy()
|
||||
|
||||
err = krn.SetArgumentValue(0, unsafe.Sizeof(uintptr(0)), unsafe.Pointer(&dbuf_v1))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = krn.SetArgumentValue(1, unsafe.Sizeof(uintptr(0)), unsafe.Pointer(&dbuf_v2))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = krn.SetGroupSize(X, Y, Z)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
lst, err := ctx.CommandListCreate(dev)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer lst.Destroy()
|
||||
|
||||
err = lst.AppendMemoryCopy(dbuf_v1, hbuf_v1, bufsz)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = lst.AppendMemoryCopy(dbuf_v2, hbuf_v2, bufsz)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = lst.AppendBarrier()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = lst.AppendLaunchKernel(krn, &gozel.ZeGroupCount{
|
||||
Groupcountx: 1, Groupcounty: 1, Groupcountz: 1,
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = lst.AppendBarrier()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = lst.AppendMemoryCopy(hbuf_v1, dbuf_v1, bufsz)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = lst.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = q.ExecuteCommandLists(lst)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = q.Synchronize()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fail := false
|
||||
for i := range N {
|
||||
expect := floatbuf[i] + floatbuf[N+i]
|
||||
if zev1[i] != expect {
|
||||
fail = true
|
||||
fmt.Printf("[%05d] expect %f = %f + %f, got %f.\n", i, expect, floatbuf[i], floatbuf[N+i], zev1[i])
|
||||
} else {
|
||||
fmt.Printf("[%05d] valid %f = %f + %f, got %f.\n", i, expect, floatbuf[i], floatbuf[N+i], zev1[i])
|
||||
}
|
||||
}
|
||||
|
||||
if fail {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
52
cmd/examples/vadd/main.ll
Normal file
52
cmd/examples/vadd/main.ll
Normal file
@@ -0,0 +1,52 @@
|
||||
; ModuleID = 'device_kern.bc'
|
||||
source_filename = "main.cpp"
|
||||
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
|
||||
target triple = "spir64-unknown-unknown"
|
||||
|
||||
@__spirv_BuiltInGlobalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
|
||||
|
||||
; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite, inaccessiblemem: write)
|
||||
define dso_local spir_kernel void @vector_add(ptr addrspace(1) noundef captures(none) %0, ptr addrspace(1) noundef readonly captures(none) %1) local_unnamed_addr #0 !sycl_used_aspects !8 !sycl_fixed_targets !10 {
|
||||
%3 = load i64, ptr addrspace(1) @__spirv_BuiltInGlobalInvocationId, align 32, !noalias !11
|
||||
%4 = icmp ult i64 %3, 2147483648
|
||||
tail call void @llvm.assume(i1 %4)
|
||||
%5 = getelementptr inbounds nuw double, ptr addrspace(1) %1, i64 %3
|
||||
%6 = load double, ptr addrspace(1) %5, align 8
|
||||
%7 = getelementptr inbounds nuw double, ptr addrspace(1) %0, i64 %3
|
||||
%8 = load double, ptr addrspace(1) %7, align 8
|
||||
%9 = fadd double %8, %6
|
||||
store double %9, ptr addrspace(1) %7, align 8
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write)
|
||||
declare void @llvm.assume(i1 noundef) #1
|
||||
|
||||
attributes #0 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite, inaccessiblemem: write) "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-entry-point" "sycl-module-id"="main.cpp" "sycl-optlevel"="2" }
|
||||
attributes #1 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: write) }
|
||||
|
||||
!llvm.dependent-libraries = !{!0}
|
||||
!llvm.module.flags = !{!1, !2, !3}
|
||||
!opencl.spir.version = !{!4}
|
||||
!spirv.Source = !{!5}
|
||||
!llvm.ident = !{!6}
|
||||
!sycl-esimd-split-status = !{!7}
|
||||
|
||||
!0 = !{!"libcpmt"}
|
||||
!1 = !{i32 1, !"wchar_size", i32 2}
|
||||
!2 = !{i32 1, !"sycl-device", i32 1}
|
||||
!3 = !{i32 7, !"frame-pointer", i32 2}
|
||||
!4 = !{i32 1, i32 2}
|
||||
!5 = !{i32 4, i32 100000}
|
||||
!6 = !{!"clang version 21.0.0git (https://github.com/intel/llvm d5f649b706f63b5c74e1929bc95db8de91085560)"}
|
||||
!7 = !{i8 0}
|
||||
!8 = !{!9}
|
||||
!9 = !{!"fp64", i32 6}
|
||||
!10 = !{}
|
||||
!11 = !{!12, !14, !16}
|
||||
!12 = distinct !{!12, !13, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv: argument 0"}
|
||||
!13 = distinct !{!13, !"_ZN7__spirv29InitSizesSTGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEE8initSizeEv"}
|
||||
!14 = distinct !{!14, !15, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v: argument 0"}
|
||||
!15 = distinct !{!15, !"_ZN7__spirv22initGlobalInvocationIdILi1EN4sycl3_V12idILi1EEEEET0_v"}
|
||||
!16 = distinct !{!16, !17, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv: argument 0"}
|
||||
!17 = distinct !{!17, !"_ZNK4sycl3_V17nd_itemILi1EE13get_global_idEv"}
|
||||
BIN
cmd/examples/vadd/main.spv
Normal file
BIN
cmd/examples/vadd/main.spv
Normal file
Binary file not shown.
28
cmd/func2kernel/main.go
Normal file
28
cmd/func2kernel/main.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
f, err := os.Open(os.Args[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer f.Close()
|
||||
fo, err := os.Create(os.Args[2])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer fo.Close()
|
||||
scan := bufio.NewScanner(f)
|
||||
for scan.Scan() {
|
||||
t := scan.Text()
|
||||
t = strings.ReplaceAll(t, " spir_func ", " spir_kernel ")
|
||||
t = strings.ReplaceAll(t, "ptr addrspace(4)", "ptr addrspace(1)")
|
||||
fo.WriteString(t)
|
||||
fo.WriteString("\n")
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"math"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ReturnTypes interface {
|
||||
@@ -18,8 +19,8 @@ type ReturnTypes interface {
|
||||
//go:uintptrescapes
|
||||
func Call[T ReturnTypes](name string, args ...uintptr) (r T, err error) {
|
||||
r1, r2, err := Syscall(name, args...)
|
||||
if err != nil {
|
||||
return
|
||||
if r1 != 0 {
|
||||
err = errors.New("zecall " + name + ": non-zero return value 0x" + strconv.FormatUint(uint64(r1), 16))
|
||||
}
|
||||
k := reflect.TypeOf(r).Kind()
|
||||
switch k {
|
||||
|
||||
91
ze/command.go
Normal file
91
ze/command.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package ze
|
||||
|
||||
import (
|
||||
"math"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/gozel"
|
||||
)
|
||||
|
||||
// CommandQueueHandle is a handle to a Level Zero command queue.
|
||||
type CommandQueueHandle gozel.ZeCommandQueueHandle
|
||||
|
||||
// CommandQueueCreate creates a command queue on the given device with default mode and normal priority.
|
||||
func (h ContextHandle) CommandQueueCreate(hDevice gozel.ZeDeviceHandle) (
|
||||
CommandQueueHandle, error,
|
||||
) {
|
||||
var q gozel.ZeCommandQueueHandle
|
||||
_, err := gozel.ZeCommandQueueCreate(gozel.ZeContextHandle(h), hDevice, &gozel.ZeCommandQueueDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
|
||||
Mode: gozel.ZE_COMMAND_QUEUE_MODE_DEFAULT,
|
||||
Priority: gozel.ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
|
||||
}, &q)
|
||||
return CommandQueueHandle(q), err
|
||||
}
|
||||
|
||||
// ExecuteCommandLists submits the command list for execution on the command queue.
|
||||
func (h CommandQueueHandle) ExecuteCommandLists(hCommandList CommandListHandle) error {
|
||||
cl := gozel.ZeCommandListHandle(hCommandList)
|
||||
_, err := gozel.ZeCommandQueueExecuteCommandLists(gozel.ZeCommandQueueHandle(h), 1, &cl, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
// Synchronize blocks the host until all commands in the command queue have completed.
|
||||
func (h CommandQueueHandle) Synchronize() error {
|
||||
_, err := gozel.ZeCommandQueueSynchronize(gozel.ZeCommandQueueHandle(h), math.MaxUint64)
|
||||
return err
|
||||
}
|
||||
|
||||
// Destroy destroys the command queue and releases its resources.
|
||||
func (h CommandQueueHandle) Destroy() error {
|
||||
_, err := gozel.ZeCommandQueueDestroy(gozel.ZeCommandQueueHandle(h))
|
||||
return err
|
||||
}
|
||||
|
||||
// CommandListHandle is a handle to a Level Zero command list.
|
||||
type CommandListHandle gozel.ZeCommandListHandle
|
||||
|
||||
// CommandListCreate creates a command list on the given device.
|
||||
func (h ContextHandle) CommandListCreate(hDevice gozel.ZeDeviceHandle) (
|
||||
CommandListHandle, error,
|
||||
) {
|
||||
var cl gozel.ZeCommandListHandle
|
||||
_, err := gozel.ZeCommandListCreate(gozel.ZeContextHandle(h), hDevice, &gozel.ZeCommandListDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_COMMAND_LIST_DESC,
|
||||
}, &cl)
|
||||
return CommandListHandle(cl), err
|
||||
}
|
||||
|
||||
// AppendLaunchKernel appends a kernel launch command to the command list.
|
||||
func (h CommandListHandle) AppendLaunchKernel(
|
||||
hKernel KernelHandle, pLaunchFuncArgs *gozel.ZeGroupCount,
|
||||
) error {
|
||||
_, err := gozel.ZeCommandListAppendLaunchKernel(gozel.ZeCommandListHandle(h), gozel.ZeKernelHandle(hKernel), pLaunchFuncArgs, 0, 0, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close closes the command list, making it ready for execution.
|
||||
func (h CommandListHandle) Close() error {
|
||||
_, err := gozel.ZeCommandListClose(gozel.ZeCommandListHandle(h))
|
||||
return err
|
||||
}
|
||||
|
||||
// AppendMemoryCopy appends a memory copy command from srcptr to dstptr of the given size.
|
||||
func (h CommandListHandle) AppendMemoryCopy(
|
||||
dstptr unsafe.Pointer, srcptr unsafe.Pointer, size uintptr,
|
||||
) error {
|
||||
_, err := gozel.ZeCommandListAppendMemoryCopy(gozel.ZeCommandListHandle(h), dstptr, srcptr, size, 0, 0, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
// Destroy destroys the command list and releases its resources.
|
||||
func (h CommandListHandle) Destroy() error {
|
||||
_, err := gozel.ZeCommandListDestroy(gozel.ZeCommandListHandle(h))
|
||||
return err
|
||||
}
|
||||
|
||||
// AppendBarrier appends an execution barrier to the command list.
|
||||
func (h CommandListHandle) AppendBarrier() error {
|
||||
_, err := gozel.ZeCommandListAppendBarrier(gozel.ZeCommandListHandle(h), 0, 0, nil)
|
||||
return err
|
||||
}
|
||||
21
ze/context.go
Normal file
21
ze/context.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package ze
|
||||
|
||||
import "github.com/fumiama/gozel"
|
||||
|
||||
// ContextHandle is a handle to a Level Zero context.
|
||||
type ContextHandle gozel.ZeContextHandle
|
||||
|
||||
// ContextCreate creates a new context for the driver.
|
||||
func (h DriverHandle) ContextCreate() (ContextHandle, error) {
|
||||
var ctx gozel.ZeContextHandle
|
||||
_, err := gozel.ZeContextCreate(gozel.ZeDriverHandle(h), &gozel.ZeContextDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_CONTEXT_DESC,
|
||||
}, &ctx)
|
||||
return ContextHandle(ctx), err
|
||||
}
|
||||
|
||||
// Destroy destroys the context and releases its resources.
|
||||
func (h ContextHandle) Destroy() error {
|
||||
_, err := gozel.ZeContextDestroy(gozel.ZeContextHandle(h))
|
||||
return err
|
||||
}
|
||||
21
ze/device.go
Normal file
21
ze/device.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package ze
|
||||
|
||||
import "github.com/fumiama/gozel"
|
||||
|
||||
// DeviceGet retrieves all devices within the driver.
|
||||
func (h DriverHandle) DeviceGet() ([]gozel.ZeDeviceHandle, error) {
|
||||
var count uint32
|
||||
_, err := gozel.ZeDeviceGet(gozel.ZeDriverHandle(h), &count, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if count == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
devices := make([]gozel.ZeDeviceHandle, count)
|
||||
_, err = gozel.ZeDeviceGet(gozel.ZeDriverHandle(h), &count, &devices[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
15
ze/init.go
15
ze/init.go
@@ -1,10 +1,15 @@
|
||||
package ze
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/gozel"
|
||||
)
|
||||
|
||||
func initDrivers(flags gozel.ZeInitDriverTypeFlags) ([]gozel.ZeDriverHandle, error) {
|
||||
// DriverHandle is a handle to a Level Zero driver instance.
|
||||
type DriverHandle gozel.ZeDriverHandle
|
||||
|
||||
func initDrivers(flags gozel.ZeInitDriverTypeFlags) ([]DriverHandle, error) {
|
||||
var count uint32
|
||||
desc := &gozel.ZeInitDriverTypeDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC,
|
||||
@@ -17,8 +22,8 @@ func initDrivers(flags gozel.ZeInitDriverTypeFlags) ([]gozel.ZeDriverHandle, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handles := make([]gozel.ZeDriverHandle, count)
|
||||
_, err = gozel.ZeInitDrivers(&count, &handles[0], desc)
|
||||
handles := make([]DriverHandle, count)
|
||||
_, err = gozel.ZeInitDrivers(&count, (*gozel.ZeDriverHandle)(unsafe.Pointer(&handles[0])), desc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -28,13 +33,13 @@ func initDrivers(flags gozel.ZeInitDriverTypeFlags) ([]gozel.ZeDriverHandle, err
|
||||
// InitGPUDrivers calls zeInitDrivers with ZE_INIT_DRIVER_TYPE_FLAG_GPU from ze_loader.dll.
|
||||
// On success pCount contains the number of drivers and phDrivers (if non-nil)
|
||||
// is filled with driver handles.
|
||||
func InitGPUDrivers() ([]gozel.ZeDriverHandle, error) {
|
||||
func InitGPUDrivers() ([]DriverHandle, error) {
|
||||
return initDrivers(gozel.ZE_INIT_DRIVER_TYPE_FLAG_GPU)
|
||||
}
|
||||
|
||||
// InitNPUDrivers calls zeInitDrivers with ZE_INIT_DRIVER_TYPE_FLAG_NPU from ze_loader.dll.
|
||||
// On success pCount contains the number of drivers and phDrivers (if non-nil)
|
||||
// is filled with driver handles.
|
||||
func InitNPUDrivers() ([]gozel.ZeDriverHandle, error) {
|
||||
func InitNPUDrivers() ([]DriverHandle, error) {
|
||||
return initDrivers(gozel.ZE_INIT_DRIVER_TYPE_FLAG_NPU)
|
||||
}
|
||||
|
||||
41
ze/kernel.go
Normal file
41
ze/kernel.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package ze
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/gozel"
|
||||
)
|
||||
|
||||
// KernelHandle is a handle to a Level Zero kernel.
|
||||
type KernelHandle gozel.ZeKernelHandle
|
||||
|
||||
// KernelCreate creates a kernel from the module by the given function name.
|
||||
func (h ModuleHandle) KernelCreate(kernelName string) (KernelHandle, error) {
|
||||
b := []byte(kernelName + "\x00")
|
||||
var k gozel.ZeKernelHandle
|
||||
_, err := gozel.ZeKernelCreate(gozel.ZeModuleHandle(h), &gozel.ZeKernelDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_KERNEL_DESC,
|
||||
Pkernelname: &b[0],
|
||||
}, &k)
|
||||
runtime.KeepAlive(b)
|
||||
return KernelHandle(k), err
|
||||
}
|
||||
|
||||
// SetArgumentValue sets the value of a kernel argument at the given index.
|
||||
func (h KernelHandle) SetArgumentValue(argIndex uint32, argSize uintptr, pArgValue unsafe.Pointer) error {
|
||||
_, err := gozel.ZeKernelSetArgumentValue(gozel.ZeKernelHandle(h), argIndex, argSize, pArgValue)
|
||||
return err
|
||||
}
|
||||
|
||||
// SetGroupSize sets the thread group size for the kernel.
|
||||
func (h KernelHandle) SetGroupSize(groupSizeX uint32, groupSizeY uint32, groupSizeZ uint32) error {
|
||||
_, err := gozel.ZeKernelSetGroupSize(gozel.ZeKernelHandle(h), groupSizeX, groupSizeY, groupSizeZ)
|
||||
return err
|
||||
}
|
||||
|
||||
// Destroy destroys the kernel and releases its resources.
|
||||
func (h KernelHandle) Destroy() error {
|
||||
_, err := gozel.ZeKernelDestroy(gozel.ZeKernelHandle(h))
|
||||
return err
|
||||
}
|
||||
35
ze/mem.go
Normal file
35
ze/mem.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ze
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/fumiama/gozel"
|
||||
)
|
||||
|
||||
// MemAllocDevice allocates device memory on the given device with the specified size and alignment.
|
||||
func (h ContextHandle) MemAllocDevice(hDevice gozel.ZeDeviceHandle, size uintptr, alignment uintptr) (
|
||||
unsafe.Pointer, error,
|
||||
) {
|
||||
var p unsafe.Pointer
|
||||
_, err := gozel.ZeMemAllocDevice(gozel.ZeContextHandle(h), &gozel.ZeDeviceMemAllocDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
|
||||
}, size, alignment, hDevice, &p)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// MemAllocHost allocates host memory with the specified size and alignment.
|
||||
func (h ContextHandle) MemAllocHost(size uintptr, alignment uintptr) (
|
||||
unsafe.Pointer, error,
|
||||
) {
|
||||
var p unsafe.Pointer
|
||||
_, err := gozel.ZeMemAllocHost(gozel.ZeContextHandle(h), &gozel.ZeHostMemAllocDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC,
|
||||
}, size, alignment, &p)
|
||||
return p, err
|
||||
}
|
||||
|
||||
// MemFree frees memory previously allocated with MemAllocDevice or MemAllocHost.
|
||||
func (h ContextHandle) MemFree(ptr unsafe.Pointer) error {
|
||||
_, err := gozel.ZeMemFree(gozel.ZeContextHandle(h), ptr)
|
||||
return err
|
||||
}
|
||||
31
ze/module.go
Normal file
31
ze/module.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package ze
|
||||
|
||||
import (
	"errors"
	"runtime"

	"github.com/fumiama/gozel"
)
|
||||
|
||||
// ModuleHandle is a handle to a Level Zero module.
|
||||
type ModuleHandle gozel.ZeModuleHandle
|
||||
|
||||
// ModuleCreate creates a module from SPIR-V binary data on the given device.
|
||||
func (h ContextHandle) ModuleCreate(hDevice gozel.ZeDeviceHandle, data []byte) (
|
||||
ModuleHandle, error,
|
||||
) {
|
||||
var m gozel.ZeModuleHandle
|
||||
_, err := gozel.ZeModuleCreate(gozel.ZeContextHandle(h), hDevice, &gozel.ZeModuleDesc{
|
||||
Stype: gozel.ZE_STRUCTURE_TYPE_MODULE_DESC,
|
||||
Format: gozel.ZE_MODULE_FORMAT_IL_SPIRV,
|
||||
Inputsize: uintptr(len(data)),
|
||||
Pinputmodule: &data[0],
|
||||
}, &m, nil)
|
||||
runtime.KeepAlive(data)
|
||||
return ModuleHandle(m), err
|
||||
}
|
||||
|
||||
// Destroy destroys the module and releases its resources.
|
||||
func (h ModuleHandle) Destroy() error {
|
||||
_, err := gozel.ZeModuleDestroy(gozel.ZeModuleHandle(h))
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user