Initial commit
This commit is contained in:
commit
26003ed956
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
/target
|
||||||
|
**/*.rs.bk
|
21
Cargo.lock
generated
Normal file
21
Cargo.lock
generated
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
[[package]]
|
||||||
|
name = "cc"
|
||||||
|
version = "1.0.36"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "cuda-test"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
"libc 0.2.53 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libc"
|
||||||
|
version = "0.2.53"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
|
||||||
|
[metadata]
|
||||||
|
"checksum cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a0c56216487bb80eec9c4516337b2588a4f2a2290d72a1416d930e4dcdb0c90d"
|
||||||
|
"checksum libc 0.2.53 (registry+https://github.com/rust-lang/crates.io-index)" = "ec350a9417dfd244dc9a6c4a71e13895a4db6b92f0b106f07ebbc3f3bc580cee"
|
16
Cargo.toml
Normal file
16
Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
[package]
|
||||||
|
name = "cuda-test"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = ["xeals <xeals@pm.me>"]
|
||||||
|
build = "build.rs"
|
||||||
|
edition = "2018"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
libc = "0.2.53"
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
cc = "1.0.36"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "cuda-test"
|
||||||
|
path = "src/rust/main.rs"
|
9
build.rs
Normal file
9
build.rs
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
fn main() {
|
||||||
|
cc::Build::new()
|
||||||
|
.cuda(true)
|
||||||
|
.flag("-cudart=shared")
|
||||||
|
.file("src/c/test.cu")
|
||||||
|
.compile("libtest.a");
|
||||||
|
println!("cargo:rustc-link-search=native=/opt/cuda/lib64");
|
||||||
|
println!("cargo:rustc-link-lib=cudart");
|
||||||
|
}
|
38
src/c/test.cu
Normal file
38
src/c/test.cu
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#include <stdlib.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include "kernel.cu"
|
||||||
|
|
||||||
|
// Wraps a CUDA runtime call and terminates the process with a diagnostic if
// the call did not return cudaSuccess.
//
// The do/while(0) wrapper makes the macro expand to exactly one statement, so
// `if (cond) CUDA_ERRCHK(call); else ...` parses the way it reads. A bare
// `{ ... }` followed by `;` would end the `if` and orphan the `else`.
#define CUDA_ERRCHK(fn) do { __gpucheck((fn), __FILE__, __LINE__); } while (0)

// Checks the status of a CUDA runtime call.
//
// code: status returned by the call being checked.
// file: source file of the call site (supplied by CUDA_ERRCHK via __FILE__).
// line: source line of the call site (supplied by CUDA_ERRCHK via __LINE__).
//
// Returns `code` unchanged when it equals cudaSuccess. For any other status
// it prints the location and the CUDA error string to stderr and exits the
// process with status -1, so it does not return in that case.
inline cudaError_t __gpucheck(cudaError_t code, const char *file, int line) {
    if (code != cudaSuccess) {
        fprintf(stderr, "CUDA runtime error [%s:%d]: %s\n", file, line, cudaGetErrorString(code));
        exit(-1);
    }
    return code;
}
|
||||||
|
|
||||||
|
// Multiplies each of the first `size` elements of `in` by `n` and stores the
// product at the same index in `out`.
//
// out:  destination array, at least `size` ints.
// in:   source array, at least `size` ints.
// n:    multiplier applied to every element.
// size: number of elements to process.
//
// Each thread starts at its global index and advances by the total number of
// threads in the launch, so any launch configuration covers every index
// exactly once.
__global__
void mulAll_kernel(int *out, int *in, int n, size_t size) {
    // Total threads in the grid: threads per block TIMES number of blocks.
    // (Adding the two gives a stride that makes threads overlap, and with a
    // single-block launch it skips elements.)
    const size_t stride = (size_t) blockDim.x * gridDim.x;

    // size_t index: matches the type of `size` and cannot overflow before
    // reaching it.
    for (size_t i = (size_t) blockIdx.x * blockDim.x + threadIdx.x;
         i < size;
         i += stride) {
        out[i] = in[i] * n;
    }
}
|
||||||
|
|
||||||
|
extern "C" {
// Host entry point (C linkage): computes out[i] = input[i] * n for `size`
// elements by running mulAll_kernel on the GPU.
//
// out:   host buffer that receives `size` ints.
// input: host buffer holding `size` ints to read.
// n:     multiplier.
// size:  number of elements in both buffers.
//
// Every CUDA runtime call is wrapped in CUDA_ERRCHK, which prints a
// diagnostic and exits the process on failure, so this function only returns
// on success.
void mulAll(int *out, const int *input, int n, size_t size) {
    // Nothing to compute: skip the zero-byte allocations, copies and launch.
    if (size == 0) {
        return;
    }

    const size_t bytes = sizeof(int) * size;
    int *in_d = NULL;
    int *out_d = NULL;

    // Stage the input on the device.
    CUDA_ERRCHK(cudaMalloc((void **) &in_d, bytes));
    CUDA_ERRCHK(cudaMemcpy(in_d, input, bytes, cudaMemcpyHostToDevice));

    CUDA_ERRCHK(cudaMalloc((void **) &out_d, bytes));
    mulAll_kernel<<<32, 32>>>(out_d, in_d, n, size);
    // A kernel launch has no return value; its launch status has to be
    // queried explicitly.
    CUDA_ERRCHK(cudaGetLastError());
    CUDA_ERRCHK(cudaMemcpy(out, out_d, bytes, cudaMemcpyDeviceToHost));

    CUDA_ERRCHK(cudaFree(in_d));
    CUDA_ERRCHK(cudaFree(out_d));
}
}
|
34
src/rust/main.rs
Normal file
34
src/rust/main.rs
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
mod cuda {
|
||||||
|
use std::os::raw::c_int;
|
||||||
|
|
||||||
|
use libc::size_t;
|
||||||
|
|
||||||
|
#[link(name = "test", kind = "static")]
|
||||||
|
extern {
|
||||||
|
fn mulAll(out: *const c_int, input: *const c_int, n: c_int, size: size_t);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mul_by(src: &[i32], by: i32) -> Vec<i32> {
|
||||||
|
unsafe {
|
||||||
|
let len = src.len() as size_t;
|
||||||
|
let psrc = src.as_ptr();
|
||||||
|
|
||||||
|
let mut res = Vec::with_capacity(src.len());
|
||||||
|
let pres = res.as_mut_ptr();
|
||||||
|
|
||||||
|
mulAll(pres, psrc, by as c_int, len);
|
||||||
|
|
||||||
|
// Turns out converting to a raw pointer drops the length information.
|
||||||
|
res.set_len(src.len());
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let v = (1..128).collect::<Vec<_>>();
|
||||||
|
let n = 3;
|
||||||
|
|
||||||
|
let o = cuda::mul_by(&v, n);
|
||||||
|
assert_eq!(o, v.iter().map(|i| i * n).collect::<Vec<_>>(), "output mangled somewhere");
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user