Initial commit
This commit is contained in:
commit
26003ed956
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
/target
|
||||
**/*.rs.bk
|
21
Cargo.lock
generated
Normal file
21
Cargo.lock
generated
Normal file
@ -0,0 +1,21 @@
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[[package]]
|
||||
name = "cuda-test"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"libc 0.2.53 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.53"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
|
||||
[metadata]
|
||||
"checksum cc 1.0.36 (registry+https://github.com/rust-lang/crates.io-index)" = "a0c56216487bb80eec9c4516337b2588a4f2a2290d72a1416d930e4dcdb0c90d"
|
||||
"checksum libc 0.2.53 (registry+https://github.com/rust-lang/crates.io-index)" = "ec350a9417dfd244dc9a6c4a71e13895a4db6b92f0b106f07ebbc3f3bc580cee"
|
16
Cargo.toml
Normal file
16
Cargo.toml
Normal file
@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "cuda-test"
|
||||
version = "0.1.0"
|
||||
authors = ["xeals <xeals@pm.me>"]
|
||||
build = "build.rs"
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
libc = "0.2.53"
|
||||
|
||||
[build-dependencies]
|
||||
cc = "1.0.36"
|
||||
|
||||
[[bin]]
|
||||
name = "cuda-test"
|
||||
path = "src/rust/main.rs"
|
9
build.rs
Normal file
9
build.rs
Normal file
@ -0,0 +1,9 @@
|
||||
fn main() {
|
||||
cc::Build::new()
|
||||
.cuda(true)
|
||||
.flag("-cudart=shared")
|
||||
.file("src/c/test.cu")
|
||||
.compile("libtest.a");
|
||||
println!("cargo:rustc-link-search=native=/opt/cuda/lib64");
|
||||
println!("cargo:rustc-link-lib=cudart");
|
||||
}
|
38
src/c/test.cu
Normal file
38
src/c/test.cu
Normal file
@ -0,0 +1,38 @@
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include "kernel.cu"
|
||||
|
||||
#define CUDA_ERRCHK(fn) { __gpucheck((fn), __FILE__, __LINE__); }
|
||||
|
||||
inline cudaError_t __gpucheck(cudaError_t code, const char *file, int line) {
|
||||
if (code != cudaSuccess) {
|
||||
fprintf(stderr, "CUDA runtime error [%s:%d]: %s\n", file, line, cudaGetErrorString(code));
|
||||
exit(-1);
|
||||
}
|
||||
return code;
|
||||
}
|
||||
|
||||
__global__
|
||||
void mulAll_kernel(int *out, int *in, int n, size_t size) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
i < size;
|
||||
i += blockDim.x + gridDim.x) {
|
||||
out[i] = in[i] * n;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
void mulAll(int *out, const int *input, int n, size_t size) {
|
||||
int *out_d, *in_d, i_d, size_d;
|
||||
|
||||
CUDA_ERRCHK(cudaMalloc((void **) &in_d, sizeof(int) * size));
|
||||
CUDA_ERRCHK(cudaMemcpy(in_d, input, sizeof(int) * size, cudaMemcpyHostToDevice));
|
||||
|
||||
CUDA_ERRCHK(cudaMalloc((void **) &out_d, sizeof(int) * size));
|
||||
mulAll_kernel << < 32, 32 >> > (out_d, in_d, n, size);
|
||||
CUDA_ERRCHK(cudaMemcpy(out, out_d, sizeof(int) * size, cudaMemcpyDeviceToHost));
|
||||
|
||||
CUDA_ERRCHK(cudaFree(in_d));
|
||||
CUDA_ERRCHK(cudaFree(out_d));
|
||||
}
|
||||
}
|
34
src/rust/main.rs
Normal file
34
src/rust/main.rs
Normal file
@ -0,0 +1,34 @@
|
||||
mod cuda {
|
||||
use std::os::raw::c_int;
|
||||
|
||||
use libc::size_t;
|
||||
|
||||
#[link(name = "test", kind = "static")]
|
||||
extern {
|
||||
fn mulAll(out: *const c_int, input: *const c_int, n: c_int, size: size_t);
|
||||
}
|
||||
|
||||
pub fn mul_by(src: &[i32], by: i32) -> Vec<i32> {
|
||||
unsafe {
|
||||
let len = src.len() as size_t;
|
||||
let psrc = src.as_ptr();
|
||||
|
||||
let mut res = Vec::with_capacity(src.len());
|
||||
let pres = res.as_mut_ptr();
|
||||
|
||||
mulAll(pres, psrc, by as c_int, len);
|
||||
|
||||
// Turns out converting to a raw pointer drops the length information.
|
||||
res.set_len(src.len());
|
||||
res
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let v = (1..128).collect::<Vec<_>>();
|
||||
let n = 3;
|
||||
|
||||
let o = cuda::mul_by(&v, n);
|
||||
assert_eq!(o, v.iter().map(|i| i * n).collect::<Vec<_>>(), "output mangled somewhere");
|
||||
}
|
Loading…
Reference in New Issue
Block a user