diff --git a/CHANGELOG.md b/CHANGELOG.md index e034c70399..107a8a23a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -177,6 +177,10 @@ By @wumpf in [#6849](https://github.com/gfx-rs/wgpu/pull/6849). - Add build support for Apple Vision Pro. By @guusw in [#6611](https://github.com/gfx-rs/wgpu/pull/6611). - Add `raw_handle` method to access raw Metal textures in [#6894](https://github.com/gfx-rs/wgpu/pull/6894). +#### D3D12 + +- Support DXR (DirectX Ray-tracing) in wgpu-hal. By @Vecvec in [#6777](https://github.com/gfx-rs/wgpu/pull/6777). + #### Changes ##### Naga diff --git a/examples/src/ray_cube_compute/mod.rs b/examples/src/ray_cube_compute/mod.rs index 801b4796ed..fb5f1abf14 100644 --- a/examples/src/ray_cube_compute/mod.rs +++ b/examples/src/ray_cube_compute/mod.rs @@ -141,7 +141,6 @@ struct Example { impl crate::framework::Example for Example { fn required_features() -> wgpu::Features { wgpu::Features::TEXTURE_BINDING_ARRAY - | wgpu::Features::STORAGE_RESOURCE_BINDING_ARRAY | wgpu::Features::VERTEX_WRITABLE_STORAGE | wgpu::Features::EXPERIMENTAL_RAY_QUERY | wgpu::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE diff --git a/naga/src/back/hlsl/help.rs b/naga/src/back/hlsl/help.rs index 347addd67e..f63c9d2cfd 100644 --- a/naga/src/back/hlsl/help.rs +++ b/naga/src/back/hlsl/help.rs @@ -841,6 +841,9 @@ impl super::Writer<'_, W> { &crate::PredeclaredType::AtomicCompareExchangeWeakResult { .. } => {} } } + if module.special_types.ray_desc.is_some() { + self.write_ray_desc_from_ray_desc_constructor_function(module)?; + } Ok(()) } @@ -852,16 +855,30 @@ impl super::Writer<'_, W> { expressions: &crate::Arena, ) -> BackendResult { for (handle, _) in expressions.iter() { - if let crate::Expression::Compose { ty, .. } = expressions[handle] { - match module.types[ty].inner { - crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. 
} => { - let constructor = WrappedConstructor { ty }; - if self.wrapped.constructors.insert(constructor) { - self.write_wrapped_constructor_function(module, constructor)?; + match expressions[handle] { + crate::Expression::Compose { ty, .. } => { + match module.types[ty].inner { + crate::TypeInner::Struct { .. } | crate::TypeInner::Array { .. } => { + let constructor = WrappedConstructor { ty }; + if self.wrapped.constructors.insert(constructor) { + self.write_wrapped_constructor_function(module, constructor)?; + } + } + _ => {} + }; + } + crate::Expression::RayQueryGetIntersection { committed, .. } => { + if committed { + if !self.written_committed_intersection { + self.write_committed_intersection_function(module)?; + self.written_committed_intersection = true; } + } else if !self.written_candidate_intersection { + self.write_candidate_intersection_function(module)?; + self.written_candidate_intersection = true; } - _ => {} - }; + } + _ => {} } } Ok(()) diff --git a/naga/src/back/hlsl/keywords.rs b/naga/src/back/hlsl/keywords.rs index 2cb715c42c..c15e17636c 100644 --- a/naga/src/back/hlsl/keywords.rs +++ b/naga/src/back/hlsl/keywords.rs @@ -814,6 +814,7 @@ pub const RESERVED: &[&str] = &[ "TextureBuffer", "ConstantBuffer", "RayQuery", + "RayDesc", // Naga utilities super::writer::MODF_FUNCTION, super::writer::FREXP_FUNCTION, diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index fe7d4f6d67..dcce866bac 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -101,6 +101,7 @@ accessing individual columns by dynamic index. 
mod conv; mod help; mod keywords; +mod ray; mod storage; mod writer; @@ -331,6 +332,8 @@ pub struct Writer<'a, W> { /// Set of expressions that have associated temporary variables named_expressions: crate::NamedExpressions, wrapped: Wrapped, + written_committed_intersection: bool, + written_candidate_intersection: bool, continue_ctx: back::continue_forward::ContinueCtx, /// A reference to some part of a global variable, lowered to a series of diff --git a/naga/src/back/hlsl/ray.rs b/naga/src/back/hlsl/ray.rs new file mode 100644 index 0000000000..ab57f06a6c --- /dev/null +++ b/naga/src/back/hlsl/ray.rs @@ -0,0 +1,163 @@ +use crate::back::hlsl::BackendResult; +use crate::{RayQueryIntersection, TypeInner}; +use std::fmt::Write; + +impl<W: Write> super::Writer<'_, W> { + // constructs hlsl RayDesc from wgsl RayDesc + pub(super) fn write_ray_desc_from_ray_desc_constructor_function( + &mut self, + module: &crate::Module, + ) -> BackendResult { + write!(self.out, "RayDesc RayDescFromRayDesc_(")?; + self.write_type(module, module.special_types.ray_desc.unwrap())?; + writeln!(self.out, " arg0) {{")?; + writeln!(self.out, " RayDesc ret = (RayDesc)0;")?; + writeln!(self.out, " ret.Origin = arg0.origin;")?; + writeln!(self.out, " ret.TMin = arg0.tmin;")?; + writeln!(self.out, " ret.Direction = arg0.dir;")?; + writeln!(self.out, " ret.TMax = arg0.tmax;")?; + writeln!(self.out, " return ret;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + pub(super) fn write_committed_intersection_function( + &mut self, + module: &crate::Module, + ) -> BackendResult { + self.write_type(module, module.special_types.ray_intersection.unwrap())?; + write!(self.out, " GetCommittedIntersection(")?; + self.write_value_type(module, &TypeInner::RayQuery)?; + writeln!(self.out, " rq) {{")?; + write!(self.out, " ")?; + self.write_type(module, module.special_types.ray_intersection.unwrap())?; + write!(self.out, " ret = (")?; + self.write_type(module, 
module.special_types.ray_intersection.unwrap())?; + writeln!(self.out, ")0;")?; + writeln!(self.out, " ret.kind = rq.CommittedStatus();")?; + writeln!( + self.out, + " if( rq.CommittedStatus() == COMMITTED_NOTHING) {{}} else {{" + )?; + writeln!(self.out, " ret.t = rq.CommittedRayT();")?; + writeln!( + self.out, + " ret.instance_custom_index = rq.CommittedInstanceID();" + )?; + writeln!( + self.out, + " ret.instance_id = rq.CommittedInstanceIndex();" + )?; + writeln!( + self.out, + " ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex();" + )?; + writeln!( + self.out, + " ret.geometry_index = rq.CommittedGeometryIndex();" + )?; + writeln!( + self.out, + " ret.primitive_index = rq.CommittedPrimitiveIndex();" + )?; + writeln!( + self.out, + " if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) {{" + )?; + writeln!( + self.out, + " ret.barycentrics = rq.CommittedTriangleBarycentrics();" + )?; + writeln!( + self.out, + " ret.front_face = rq.CommittedTriangleFrontFace();" + )?; + writeln!(self.out, " }}")?; + writeln!( + self.out, + " ret.object_to_world = rq.CommittedObjectToWorld4x3();" + )?; + writeln!( + self.out, + " ret.world_to_object = rq.CommittedWorldToObject4x3();" + )?; + writeln!(self.out, " }}")?; + writeln!(self.out, " return ret;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } + pub(super) fn write_candidate_intersection_function( + &mut self, + module: &crate::Module, + ) -> BackendResult { + self.write_type(module, module.special_types.ray_intersection.unwrap())?; + write!(self.out, " GetCandidateIntersection(")?; + self.write_value_type(module, &TypeInner::RayQuery)?; + writeln!(self.out, " rq) {{")?; + write!(self.out, " ")?; + self.write_type(module, module.special_types.ray_intersection.unwrap())?; + write!(self.out, " ret = (")?; + self.write_type(module, module.special_types.ray_intersection.unwrap())?; + writeln!(self.out, ")0;")?; + writeln!(self.out, " CANDIDATE_TYPE kind = rq.CandidateType();")?; + 
writeln!( + self.out, + " if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) {{" + )?; + writeln!( + self.out, + " ret.kind = {};", + RayQueryIntersection::Triangle as u32 + )?; + writeln!(self.out, " ret.t = rq.CandidateTriangleRayT();")?; + writeln!( + self.out, + " ret.barycentrics = rq.CandidateTriangleBarycentrics();" + )?; + writeln!( + self.out, + " ret.front_face = rq.CandidateTriangleFrontFace();" + )?; + writeln!(self.out, " }} else {{")?; + writeln!( + self.out, + " ret.kind = {};", + RayQueryIntersection::Aabb as u32 + )?; + writeln!(self.out, " }}")?; + + writeln!( + self.out, + " ret.instance_custom_index = rq.CandidateInstanceID();" + )?; + writeln!( + self.out, + " ret.instance_id = rq.CandidateInstanceIndex();" + )?; + writeln!( + self.out, + " ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex();" + )?; + writeln!( + self.out, + " ret.geometry_index = rq.CandidateGeometryIndex();" + )?; + writeln!( + self.out, + " ret.primitive_index = rq.CandidatePrimitiveIndex();" + )?; + writeln!( + self.out, + " ret.object_to_world = rq.CandidateObjectToWorld4x3();" + )?; + writeln!( + self.out, + " ret.world_to_object = rq.CandidateWorldToObject4x3();" + )?; + writeln!(self.out, " return ret;")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + Ok(()) + } +} diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 8dff67f1fc..b5df135766 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -9,7 +9,7 @@ use super::{ use crate::{ back::{self, Baked}, proc::{self, index, ExpressionKindTracker, NameKey}, - valid, Handle, Module, Scalar, ScalarKind, ShaderStage, TypeInner, + valid, Handle, Module, RayQueryFunction, Scalar, ScalarKind, ShaderStage, TypeInner, }; use std::{fmt, mem}; @@ -104,6 +104,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { entry_point_io: Vec::new(), named_expressions: crate::NamedExpressions::default(), wrapped: super::Wrapped::default(), + written_committed_intersection: 
false, + written_candidate_intersection: false, continue_ctx: back::continue_forward::ContinueCtx::default(), temp_access_chain: Vec::new(), need_bake_expressions: Default::default(), @@ -123,6 +125,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.entry_point_io.clear(); self.named_expressions.clear(); self.wrapped.clear(); + self.written_committed_intersection = false; + self.written_candidate_intersection = false; self.continue_ctx.clear(); self.need_bake_expressions.clear(); } @@ -1218,6 +1222,13 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { TypeInner::Array { base, size, .. } | TypeInner::BindingArray { base, size } => { self.write_array_size(module, base, size)?; } + TypeInner::AccelerationStructure => { + write!(self.out, "RaytracingAccelerationStructure")?; + } + TypeInner::RayQuery => { + // these are constant flags, there are dynamic flags also but constant flags are not supported by naga + write!(self.out, "RayQuery")?; + } _ => return Err(Error::Unimplemented(format!("write_value_type {inner:?}"))), } @@ -1375,15 +1386,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_array_size(module, base, size)?; } - write!(self.out, " = ")?; - // Write the local initializer if needed - if let Some(init) = local.init { - self.write_expr(module, init, func_ctx)?; - } else { - // Zero initialize local variables - self.write_default_init(module, local.ty)?; + match module.types[local.ty].inner { + // from https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#tracerayinline-example-1 it seems that ray queries shouldn't be zeroed + TypeInner::RayQuery => {} + _ => { + write!(self.out, " = ")?; + // Write the local initializer if needed + if let Some(init) = local.init { + self.write_expr(module, init, func_ctx)?; + } else { + // Zero initialize local variables + self.write_default_init(module, local.ty)?; + } + } } - // Finish the local with `;` and add a newline (only for readability) writeln!(self.out, ";")? 
} @@ -2250,7 +2266,38 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } => { self.write_switch(module, func_ctx, level, selector, cases)?; } - Statement::RayQuery { .. } => unreachable!(), + Statement::RayQuery { query, ref fun } => match *fun { + RayQueryFunction::Initialize { + acceleration_structure, + descriptor, + } => { + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + write!(self.out, ".TraceRayInline(")?; + self.write_expr(module, acceleration_structure, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, descriptor, func_ctx)?; + write!(self.out, ".flags, ")?; + self.write_expr(module, descriptor, func_ctx)?; + write!(self.out, ".cull_mask, ")?; + write!(self.out, "RayDescFromRayDesc_(")?; + self.write_expr(module, descriptor, func_ctx)?; + writeln!(self.out, "));")?; + } + RayQueryFunction::Proceed { result } => { + write!(self.out, "{level}")?; + let name = Baked(result).to_string(); + write!(self.out, "const bool {name} = ")?; + self.named_expressions.insert(result, name); + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".Proceed();")?; + } + RayQueryFunction::Terminate => { + write!(self.out, "{level}")?; + self.write_expr(module, query, func_ctx)?; + writeln!(self.out, ".Abort();")?; + } + }, Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); @@ -3608,8 +3655,17 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_expr(module, reject, func_ctx)?; write!(self.out, ")")? } - // Not supported yet - Expression::RayQueryGetIntersection { .. 
} => unreachable!(), + Expression::RayQueryGetIntersection { query, committed } => { + if committed { + write!(self.out, "GetCommittedIntersection(")?; + self.write_expr(module, query, func_ctx)?; + write!(self.out, ")")?; + } else { + write!(self.out, "GetCandidateIntersection(")?; + self.write_expr(module, query, func_ctx)?; + write!(self.out, ")")?; + } + } // Nothing to do here, since call expression already cached Expression::CallResult(_) | Expression::AtomicResult { .. } diff --git a/naga/tests/in/ray-query.param.ron b/naga/tests/in/ray-query.param.ron index c400db8c64..481d311fa4 100644 --- a/naga/tests/in/ray-query.param.ron +++ b/naga/tests/in/ray-query.param.ron @@ -11,4 +11,11 @@ per_entry_point_map: {}, inline_samplers: [], ), + hlsl: ( + shader_model: V6_5, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ) ) diff --git a/naga/tests/out/hlsl/ray-query.hlsl b/naga/tests/out/hlsl/ray-query.hlsl new file mode 100644 index 0000000000..9a0a2da1ce --- /dev/null +++ b/naga/tests/out/hlsl/ray-query.hlsl @@ -0,0 +1,152 @@ +struct RayIntersection { + uint kind; + float t; + uint instance_custom_index; + uint instance_id; + uint sbt_record_offset; + uint geometry_index; + uint primitive_index; + float2 barycentrics; + bool front_face; + int _pad9_0; + int _pad9_1; + row_major float4x3 object_to_world; + int _pad10_0; + row_major float4x3 world_to_object; + int _end_pad_0; +}; + +struct RayDesc_ { + uint flags; + uint cull_mask; + float tmin; + float tmax; + float3 origin; + int _pad5_0; + float3 dir; + int _end_pad_0; +}; + +struct Output { + uint visible; + int _pad1_0; + int _pad1_1; + int _pad1_2; + float3 normal; + int _end_pad_0; +}; + +RayDesc RayDescFromRayDesc_(RayDesc_ arg0) { + RayDesc ret = (RayDesc)0; + ret.Origin = arg0.origin; + ret.TMin = arg0.tmin; + ret.Direction = arg0.dir; + ret.TMax = arg0.tmax; + return ret; +} + +RaytracingAccelerationStructure acc_struct : 
register(t0); +RWByteAddressBuffer output : register(u1); + +RayDesc_ ConstructRayDesc_(uint arg0, uint arg1, float arg2, float arg3, float3 arg4, float3 arg5) { + RayDesc_ ret = (RayDesc_)0; + ret.flags = arg0; + ret.cull_mask = arg1; + ret.tmin = arg2; + ret.tmax = arg3; + ret.origin = arg4; + ret.dir = arg5; + return ret; +} + +RayIntersection GetCommittedIntersection(RayQuery rq) { + RayIntersection ret = (RayIntersection)0; + ret.kind = rq.CommittedStatus(); + if( rq.CommittedStatus() == COMMITTED_NOTHING) {} else { + ret.t = rq.CommittedRayT(); + ret.instance_custom_index = rq.CommittedInstanceID(); + ret.instance_id = rq.CommittedInstanceIndex(); + ret.sbt_record_offset = rq.CommittedInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CommittedGeometryIndex(); + ret.primitive_index = rq.CommittedPrimitiveIndex(); + if( rq.CommittedStatus() == COMMITTED_TRIANGLE_HIT ) { + ret.barycentrics = rq.CommittedTriangleBarycentrics(); + ret.front_face = rq.CommittedTriangleFrontFace(); + } + ret.object_to_world = rq.CommittedObjectToWorld4x3(); + ret.world_to_object = rq.CommittedWorldToObject4x3(); + } + return ret; +} + +RayIntersection query_loop(float3 pos, float3 dir, RaytracingAccelerationStructure acs) +{ + RayQuery rq_1; + + rq_1.TraceRayInline(acs, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos, dir))); + while(true) { + const bool _e9 = rq_1.Proceed(); + if (_e9) { + } else { + break; + } + { + } + } + const RayIntersection rayintersection = GetCommittedIntersection(rq_1); + return rayintersection; +} + +float3 get_torus_normal(float3 world_point, RayIntersection intersection) +{ + float3 local_point = mul(float4(world_point, 1.0), intersection.world_to_object); + float2 point_on_guiding_line = (normalize(local_point.xy) * 2.4); + float3 world_point_on_guiding_line = mul(float4(point_on_guiding_line, 0.0, 1.0), 
intersection.object_to_world); + return normalize((world_point - world_point_on_guiding_line)); +} + +[numthreads(1, 1, 1)] +void main() +{ + float3 pos_1 = (0.0).xxx; + float3 dir_1 = float3(0.0, 1.0, 0.0); + const RayIntersection _e7 = query_loop(pos_1, dir_1, acc_struct); + output.Store(0, asuint(uint((_e7.kind == 0u)))); + const float3 _e18 = get_torus_normal((dir_1 * _e7.t), _e7); + output.Store3(16, asuint(_e18)); + return; +} + +RayIntersection GetCandidateIntersection(RayQuery rq) { + RayIntersection ret = (RayIntersection)0; + CANDIDATE_TYPE kind = rq.CandidateType(); + if (kind == CANDIDATE_NON_OPAQUE_TRIANGLE) { + ret.kind = 1; + ret.t = rq.CandidateTriangleRayT(); + ret.barycentrics = rq.CandidateTriangleBarycentrics(); + ret.front_face = rq.CandidateTriangleFrontFace(); + } else { + ret.kind = 3; + } + ret.instance_custom_index = rq.CandidateInstanceID(); + ret.instance_id = rq.CandidateInstanceIndex(); + ret.sbt_record_offset = rq.CandidateInstanceContributionToHitGroupIndex(); + ret.geometry_index = rq.CandidateGeometryIndex(); + ret.primitive_index = rq.CandidatePrimitiveIndex(); + ret.object_to_world = rq.CandidateObjectToWorld4x3(); + ret.world_to_object = rq.CandidateWorldToObject4x3(); + return ret; +} + +[numthreads(1, 1, 1)] +void main_candidate() +{ + RayQuery rq; + + float3 pos_2 = (0.0).xxx; + float3 dir_2 = float3(0.0, 1.0, 0.0); + rq.TraceRayInline(acc_struct, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).flags, ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2).cull_mask, RayDescFromRayDesc_(ConstructRayDesc_(4u, 255u, 0.1, 100.0, pos_2, dir_2))); + RayIntersection intersection_1 = GetCandidateIntersection(rq); + output.Store(0, asuint(uint((intersection_1.kind == 3u)))); + return; +} diff --git a/naga/tests/out/hlsl/ray-query.ron b/naga/tests/out/hlsl/ray-query.ron new file mode 100644 index 0000000000..a31e1db125 --- /dev/null +++ b/naga/tests/out/hlsl/ray-query.ron @@ -0,0 +1,16 @@ +( + vertex:[ + ], + fragment:[ + ], + 
compute:[ + ( + entry_point:"main", + target_profile:"cs_6_5", + ), + ( + entry_point:"main_candidate", + target_profile:"cs_6_5", + ), + ], +) diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index db5fcdf19e..73ac99455e 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -875,7 +875,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), - ("ray-query", Targets::SPIRV | Targets::METAL), + ("ray-query", Targets::SPIRV | Targets::METAL | Targets::HLSL), ("hlsl-keyword", Targets::HLSL), ( "constructors", diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs index 9987380c34..3e048e9396 100644 --- a/wgpu-hal/examples/ray-traced-triangle/main.rs +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@ -284,7 +284,7 @@ impl Example { dbg!(&surface_caps.formats); let surface_format = if surface_caps .formats - .contains(&wgt::TextureFormat::Rgba8Snorm) + .contains(&wgt::TextureFormat::Rgba8Unorm) { wgt::TextureFormat::Rgba8Unorm } else { @@ -473,7 +473,8 @@ impl Example { vertex_buffer: Some(&vertices_buffer), first_vertex: 0, vertex_format: wgt::VertexFormat::Float32x3, - vertex_count: vertices.len() as u32, + // each vertex is 3 floats, and floats are stored raw in the array + vertex_count: vertices.len() as u32 / 3, vertex_stride: 3 * 4, indices: indices_buffer.as_ref().map(|(buf, len)| { hal::AccelerationStructureTriangleIndices { diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 7859f06e5d..4558fb92e2 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -244,7 +244,6 @@ impl super::Adapter { _ => unreachable!(), } }; - let private_caps = super::PrivateCapabilities { instance_flags, heterogeneous_resource_heaps: options.ResourceHeapTier @@ -395,6 +394,27 @@ impl super::Adapter { && hr.is_ok() && features1.WaveOps.as_bool(), ); + let mut features5 
= Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS5::default(); + let has_features5 = unsafe { + device.CheckFeatureSupport( + Direct3D12::D3D12_FEATURE_D3D12_OPTIONS5, + <*mut _>::cast(&mut features5), + size_of_val(&features5) as u32, + ) + } + .is_ok(); + + // Since all features for raytracing pipeline (geometry index) and ray queries both come + // from here, there is no point in adding an extra call here given that there will be no + // feature using EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE if all these are not met. + // Once ray tracing pipelines are supported they also will go here + features.set( + wgt::Features::EXPERIMENTAL_RAY_QUERY + | wgt::Features::EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE, + features5.RaytracingTier == Direct3D12::D3D12_RAYTRACING_TIER_1_1 + && shader_model >= naga::back::hlsl::ShaderModel::V6_5 + && has_features5, + ); let atomic_int64_on_typed_resource_supported = { let mut features9 = Direct3D12::D3D12_FEATURE_DATA_D3D12_OPTIONS9::default(); @@ -529,8 +549,9 @@ impl super::Adapter { // Direct3D correctly bounds-checks all array accesses: // https://microsoft.github.io/DirectX-Specs/d3d/archive/D3D11_3_FunctionalSpec.htm#18.6.8.2%20Device%20Memory%20Reads uniform_bounds_check_alignment: wgt::BufferSize::new(1).unwrap(), - raw_tlas_instance_size: 0, - ray_tracing_scratch_buffer_alignment: 0, + raw_tlas_instance_size: size_of::(), + ray_tracing_scratch_buffer_alignment: + Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BYTE_ALIGNMENT, }, downlevel, }, diff --git a/wgpu-hal/src/dx12/command.rs b/wgpu-hal/src/dx12/command.rs index 9296a20393..99cee37373 100644 --- a/wgpu-hal/src/dx12/command.rs +++ b/wgpu-hal/src/dx12/command.rs @@ -1,12 +1,15 @@ -use std::{mem, ops::Range}; - -use windows::Win32::{Foundation, Graphics::Direct3D12}; - use super::conv; use crate::{ auxil::{self, dxgi::result::HResult as _}, dx12::borrow_interface_temporarily, + AccelerationStructureEntries, +}; +use std::{mem, ops::Range}; +use 
windows::Win32::{ + Foundation, + Graphics::{Direct3D12, Dxgi}, }; +use windows_core::Interface; fn make_box(origin: &wgt::Origin3d, size: &crate::CopyExtent) -> Direct3D12::D3D12_BOX { Direct3D12::D3D12_BOX { @@ -777,8 +780,8 @@ impl crate::CommandEncoder for super::CommandEncoder { // ) // TODO: Replace with the above in the next breaking windows-rs release, // when https://github.com/microsoft/win32metadata/pull/1971 is in. - (windows_core::Interface::vtable(list).ClearDepthStencilView)( - windows_core::Interface::as_raw(list), + (Interface::vtable(list).ClearDepthStencilView)( + Interface::as_raw(list), ds_view, flags, ds.clear_value.0, @@ -1259,7 +1262,7 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn build_acceleration_structures<'a, T>( &mut self, _descriptor_count: u32, - _descriptors: T, + descriptors: T, ) where super::Api: 'a, T: IntoIterator< @@ -1272,13 +1275,189 @@ impl crate::CommandEncoder for super::CommandEncoder { { // Implement using `BuildRaytracingAccelerationStructure`: // https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#buildraytracingaccelerationstructure - todo!() + let list = self + .list + .as_ref() + .unwrap() + .cast::() + .unwrap(); + for descriptor in descriptors { + // TODO: This is the same as getting build sizes apart from requiring buffers, should this be de-duped? 
+ let mut geometry_desc; + let ty; + let inputs0; + let num_desc; + match descriptor.entries { + AccelerationStructureEntries::Instances(instances) => { + let desc_address = unsafe { + instances + .buffer + .expect("needs buffer to build") + .resource + .GetGPUVirtualAddress() + } + instances.offset as u64; + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + InstanceDescs: desc_address, + }; + num_desc = instances.count; + } + AccelerationStructureEntries::Triangles(triangles) => { + geometry_desc = Vec::with_capacity(triangles.len()); + for triangle in triangles { + let transform_address = + triangle.transform.as_ref().map_or(0, |transform| unsafe { + transform.buffer.resource.GetGPUVirtualAddress() + + transform.offset as u64 + }); + let index_format = triangle + .indices + .as_ref() + .map_or(Dxgi::Common::DXGI_FORMAT_UNKNOWN, |indices| { + auxil::dxgi::conv::map_index_format(indices.format) + }); + let vertex_format = + auxil::dxgi::conv::map_vertex_format(triangle.vertex_format); + let index_count = + triangle.indices.as_ref().map_or(0, |indices| indices.count); + let index_address = triangle.indices.as_ref().map_or(0, |indices| unsafe { + indices + .buffer + .expect("needs buffer to build") + .resource + .GetGPUVirtualAddress() + + indices.offset as u64 + }); + let vertex_address = unsafe { + triangle + .vertex_buffer + .expect("needs buffer to build") + .resource + .GetGPUVirtualAddress() + + (triangle.first_vertex as u64 * triangle.vertex_stride) + }; + + let triangle_desc = Direct3D12::D3D12_RAYTRACING_GEOMETRY_TRIANGLES_DESC { + Transform3x4: transform_address, + IndexFormat: index_format, + VertexFormat: vertex_format, + IndexCount: index_count, + VertexCount: triangle.vertex_count, + IndexBuffer: index_address, + VertexBuffer: Direct3D12::D3D12_GPU_VIRTUAL_ADDRESS_AND_STRIDE { + StartAddress: vertex_address, + StrideInBytes: triangle.vertex_stride, 
+ }, + }; + + geometry_desc.push(Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC { + Type: Direct3D12::D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES, + Flags: conv::map_acceleration_structure_geometry_flags(triangle.flags), + Anonymous: Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC_0 { + Triangles: triangle_desc, + }, + }) + } + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + pGeometryDescs: geometry_desc.as_ptr(), + }; + num_desc = geometry_desc.len() as u32; + } + AccelerationStructureEntries::AABBs(aabbs) => { + geometry_desc = Vec::with_capacity(aabbs.len()); + for aabb in aabbs { + let aabb_address = unsafe { + aabb.buffer + .expect("needs buffer to build") + .resource + .GetGPUVirtualAddress() + + (aabb.offset as u64 * aabb.stride) + }; + + let aabb_desc = Direct3D12::D3D12_RAYTRACING_GEOMETRY_AABBS_DESC { + AABBCount: aabb.count as u64, + AABBs: Direct3D12::D3D12_GPU_VIRTUAL_ADDRESS_AND_STRIDE { + StartAddress: aabb_address, + StrideInBytes: aabb.stride, + }, + }; + + geometry_desc.push(Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC { + Type: Direct3D12::D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS, + Flags: conv::map_acceleration_structure_geometry_flags(aabb.flags), + Anonymous: Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC_0 { + AABBs: aabb_desc, + }, + }) + } + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + pGeometryDescs: geometry_desc.as_ptr(), + }; + num_desc = geometry_desc.len() as u32; + } + }; + let acceleration_structure_inputs = + Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS { + Type: ty, + Flags: conv::map_acceleration_structure_build_flags( + descriptor.flags, + Some(descriptor.mode), + ), + NumDescs: num_desc, + DescsLayout: Direct3D12::D3D12_ELEMENTS_LAYOUT_ARRAY, + Anonymous: inputs0, + }; + + let 
dst_acceleration_structure_address = unsafe { + descriptor + .destination_acceleration_structure + .resource + .GetGPUVirtualAddress() + }; + let src_acceleration_structure_address = descriptor + .source_acceleration_structure + .as_ref() + .map_or(0, |source| unsafe { + source.resource.GetGPUVirtualAddress() + }); + let scratch_address = unsafe { + descriptor.scratch_buffer.resource.GetGPUVirtualAddress() + + descriptor.scratch_buffer_offset + }; + + let desc = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC { + DestAccelerationStructureData: dst_acceleration_structure_address, + Inputs: acceleration_structure_inputs, + SourceAccelerationStructureData: src_acceleration_structure_address, + ScratchAccelerationStructureData: scratch_address, + }; + unsafe { list.BuildRaytracingAccelerationStructure(&desc, None) }; + } } unsafe fn place_acceleration_structure_barrier( &mut self, _barriers: crate::AccelerationStructureBarrier, ) { - todo!() + // TODO: This is not very optimal, we should be using [enhanced barriers](https://microsoft.github.io/DirectX-Specs/d3d/D3D12EnhancedBarriers.html) if possible + let list = self + .list + .as_ref() + .unwrap() + .cast::() + .unwrap(); + unsafe { + list.ResourceBarrier(&[Direct3D12::D3D12_RESOURCE_BARRIER { + Type: Direct3D12::D3D12_RESOURCE_BARRIER_TYPE_UAV, + Flags: Direct3D12::D3D12_RESOURCE_BARRIER_FLAG_NONE, + Anonymous: Direct3D12::D3D12_RESOURCE_BARRIER_0 { + UAV: mem::ManuallyDrop::new(Direct3D12::D3D12_RESOURCE_UAV_BARRIER { + pResource: Default::default(), + }), + }, + }]) + } } } diff --git a/wgpu-hal/src/dx12/conv.rs b/wgpu-hal/src/dx12/conv.rs index 3457d6446e..5117378942 100644 --- a/wgpu-hal/src/dx12/conv.rs +++ b/wgpu-hal/src/dx12/conv.rs @@ -112,7 +112,7 @@ pub fn map_binding_type(ty: &wgt::BindingType) -> Direct3D12::D3D12_DESCRIPTOR_R .. } | Bt::StorageTexture { .. 
} => Direct3D12::D3D12_DESCRIPTOR_RANGE_TYPE_UAV, - Bt::AccelerationStructure => todo!(), + Bt::AccelerationStructure => Direct3D12::D3D12_DESCRIPTOR_RANGE_TYPE_SRV, } } @@ -350,3 +350,51 @@ pub fn map_depth_stencil(ds: &wgt::DepthStencilState) -> Direct3D12::D3D12_DEPTH BackFace: map_stencil_face(&ds.stencil.back), } } + +pub(crate) fn map_acceleration_structure_build_flags( + flags: wgt::AccelerationStructureFlags, + mode: Option<crate::AccelerationStructureBuildMode>, +) -> Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAGS { + let mut d3d_flags = Default::default(); + if flags.contains(wgt::AccelerationStructureFlags::ALLOW_COMPACTION) { + d3d_flags |= + Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_COMPACTION; + } + + if flags.contains(wgt::AccelerationStructureFlags::ALLOW_UPDATE) { + d3d_flags |= Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_ALLOW_UPDATE; + } + + if flags.contains(wgt::AccelerationStructureFlags::LOW_MEMORY) { + d3d_flags |= Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_MINIMIZE_MEMORY; + } + + if flags.contains(wgt::AccelerationStructureFlags::PREFER_FAST_BUILD) { + d3d_flags |= + Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_BUILD; + } + + if flags.contains(wgt::AccelerationStructureFlags::PREFER_FAST_TRACE) { + d3d_flags |= + Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PREFER_FAST_TRACE; + } + + if let Some(crate::AccelerationStructureBuildMode::Update) = mode { + d3d_flags |= Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_BUILD_FLAG_PERFORM_UPDATE; + } + + d3d_flags +} + +pub(crate) fn map_acceleration_structure_geometry_flags( + flags: wgt::AccelerationStructureGeometryFlags, +) -> Direct3D12::D3D12_RAYTRACING_GEOMETRY_FLAGS { + let mut d3d_flags = Default::default(); + if flags.contains(wgt::AccelerationStructureGeometryFlags::OPAQUE) { + d3d_flags |= Direct3D12::D3D12_RAYTRACING_GEOMETRY_FLAG_OPAQUE; + } + if 
flags.contains(wgt::AccelerationStructureGeometryFlags::NO_DUPLICATE_ANY_HIT_INVOCATION) { + d3d_flags |= Direct3D12::D3D12_RAYTRACING_GEOMETRY_FLAG_NO_DUPLICATE_ANYHIT_INVOCATION; + } + d3d_flags +} diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 20dc20164f..b9a825845a 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -2,7 +2,7 @@ use std::{ ffi, mem::{self, size_of, size_of_val}, num::NonZeroU32, - ptr, + ptr, slice, sync::Arc, time::{Duration, Instant}, }; @@ -21,7 +21,7 @@ use super::{conv, descriptor, D3D12Lib}; use crate::{ auxil::{self, dxgi::result::HResult}, dx12::{borrow_optional_interface_temporarily, shader_compilation, Event}, - TlasInstance, + AccelerationStructureEntries, TlasInstance, }; // this has to match Naga's HLSL backend, and also needs to be null-terminated @@ -763,7 +763,12 @@ impl crate::Device for super::Device { &self, desc: &crate::BindGroupLayoutDescriptor, ) -> Result { - let (mut num_buffer_views, mut num_samplers, mut num_texture_views) = (0, 0, 0); + let ( + mut num_buffer_views, + mut num_samplers, + mut num_texture_views, + mut num_acceleration_structures, + ) = (0, 0, 0, 0); for entry in desc.entries.iter() { let count = entry.count.map_or(1, NonZeroU32::get); match entry.ty { @@ -776,13 +781,13 @@ impl crate::Device for super::Device { num_texture_views += count } wgt::BindingType::Sampler { .. 
} => num_samplers += count, - wgt::BindingType::AccelerationStructure => todo!(), + wgt::BindingType::AccelerationStructure => num_acceleration_structures += count, } } self.counters.bind_group_layouts.add(1); - let num_views = num_buffer_views + num_texture_views; + let num_views = num_buffer_views + num_texture_views + num_acceleration_structures; Ok(super::BindGroupLayout { entries: desc.entries.to_vec(), cpu_heap_views: if num_views != 0 { @@ -1389,7 +1394,33 @@ impl crate::Device for super::Device { cpu_samplers.as_mut().unwrap().stage.push(data.handle.raw); } } - wgt::BindingType::AccelerationStructure => todo!(), + wgt::BindingType::AccelerationStructure => { + let start = entry.resource_index as usize; + let end = start + entry.count as usize; + for data in &desc.acceleration_structures[start..end] { + let inner = cpu_views.as_mut().unwrap(); + let cpu_index = inner.stage.len() as u32; + let handle = desc.layout.cpu_heap_views.as_ref().unwrap().at(cpu_index); + let raw_desc = Direct3D12::D3D12_SHADER_RESOURCE_VIEW_DESC { + Format: Dxgi::Common::DXGI_FORMAT_UNKNOWN, + Shader4ComponentMapping: + Direct3D12::D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING, + ViewDimension: + Direct3D12::D3D12_SRV_DIMENSION_RAYTRACING_ACCELERATION_STRUCTURE, + Anonymous: Direct3D12::D3D12_SHADER_RESOURCE_VIEW_DESC_0 { + RaytracingAccelerationStructure: + Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_SRV { + Location: unsafe { data.resource.GetGPUVirtualAddress() }, + }, + }, + }; + unsafe { + self.raw + .CreateShaderResourceView(None, Some(&raw_desc), handle) + }; + inner.stage.push(handle); + } + } } } @@ -1888,36 +1919,167 @@ impl crate::Device for super::Device { unsafe fn get_acceleration_structure_build_sizes<'a>( &self, - _desc: &crate::GetAccelerationStructureBuildSizesDescriptor<'a, super::Buffer>, + desc: &crate::GetAccelerationStructureBuildSizesDescriptor<'a, super::Buffer>, ) -> crate::AccelerationStructureBuildSizes { - // Implement using 
`GetRaytracingAccelerationStructurePrebuildInfo`: - // https://microsoft.github.io/DirectX-Specs/d3d/Raytracing.html#getraytracingaccelerationstructureprebuildinfo - todo!() + let mut geometry_desc; + let device5 = self.raw.cast::().unwrap(); + let ty; + let inputs0; + let num_desc; + match desc.entries { + AccelerationStructureEntries::Instances(instances) => { + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + InstanceDescs: 0, + }; + num_desc = instances.count; + } + AccelerationStructureEntries::Triangles(triangles) => { + geometry_desc = Vec::with_capacity(triangles.len()); + for triangle in triangles { + let index_format = triangle + .indices + .as_ref() + .map_or(Dxgi::Common::DXGI_FORMAT_UNKNOWN, |indices| { + auxil::dxgi::conv::map_index_format(indices.format) + }); + let index_count = triangle.indices.as_ref().map_or(0, |indices| indices.count); + + let triangle_desc = Direct3D12::D3D12_RAYTRACING_GEOMETRY_TRIANGLES_DESC { + Transform3x4: 0, + IndexFormat: index_format, + VertexFormat: auxil::dxgi::conv::map_vertex_format(triangle.vertex_format), + IndexCount: index_count, + VertexCount: triangle.vertex_count, + IndexBuffer: 0, + VertexBuffer: Direct3D12::D3D12_GPU_VIRTUAL_ADDRESS_AND_STRIDE { + StartAddress: 0, + StrideInBytes: triangle.vertex_stride, + }, + }; + + geometry_desc.push(Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC { + Type: Direct3D12::D3D12_RAYTRACING_GEOMETRY_TYPE_TRIANGLES, + Flags: conv::map_acceleration_structure_geometry_flags(triangle.flags), + Anonymous: Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC_0 { + Triangles: triangle_desc, + }, + }) + } + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + pGeometryDescs: geometry_desc.as_ptr(), + }; + num_desc = geometry_desc.len() as u32; + } + 
AccelerationStructureEntries::AABBs(aabbs) => { + geometry_desc = Vec::with_capacity(aabbs.len()); + for aabb in aabbs { + let aabb_desc = Direct3D12::D3D12_RAYTRACING_GEOMETRY_AABBS_DESC { + AABBCount: aabb.count as u64, + AABBs: Direct3D12::D3D12_GPU_VIRTUAL_ADDRESS_AND_STRIDE { + StartAddress: 0, + StrideInBytes: aabb.stride, + }, + }; + geometry_desc.push(Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC { + Type: Direct3D12::D3D12_RAYTRACING_GEOMETRY_TYPE_PROCEDURAL_PRIMITIVE_AABBS, + Flags: conv::map_acceleration_structure_geometry_flags(aabb.flags), + Anonymous: Direct3D12::D3D12_RAYTRACING_GEOMETRY_DESC_0 { + AABBs: aabb_desc, + }, + }) + } + ty = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_TYPE_BOTTOM_LEVEL; + inputs0 = Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS_0 { + pGeometryDescs: geometry_desc.as_ptr(), + }; + num_desc = geometry_desc.len() as u32; + } + }; + let acceleration_structure_inputs = + Direct3D12::D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_INPUTS { + Type: ty, + Flags: conv::map_acceleration_structure_build_flags(desc.flags, None), + NumDescs: num_desc, + DescsLayout: Direct3D12::D3D12_ELEMENTS_LAYOUT_ARRAY, + Anonymous: inputs0, + }; + let mut info = Direct3D12::D3D12_RAYTRACING_ACCELERATION_STRUCTURE_PREBUILD_INFO::default(); + unsafe { + device5.GetRaytracingAccelerationStructurePrebuildInfo( + &acceleration_structure_inputs, + &mut info, + ) + }; + crate::AccelerationStructureBuildSizes { + acceleration_structure_size: info.ResultDataMaxSizeInBytes, + update_scratch_size: info.UpdateScratchDataSizeInBytes, + build_scratch_size: info.ScratchDataSizeInBytes, + } } unsafe fn get_acceleration_structure_device_address( &self, - _acceleration_structure: &super::AccelerationStructure, + acceleration_structure: &super::AccelerationStructure, ) -> wgt::BufferAddress { - // Implement using `GetGPUVirtualAddress`: - // https://docs.microsoft.com/en-us/windows/win32/api/d3d12/nf-d3d12-id3d12resource-getgpuvirtualaddress - 
todo!() + unsafe { acceleration_structure.resource.GetGPUVirtualAddress() } } unsafe fn create_acceleration_structure( &self, - _desc: &crate::AccelerationStructureDescriptor, + desc: &crate::AccelerationStructureDescriptor, ) -> Result { // Create a D3D12 resource as per-usual. - todo!() + let size = desc.size; + + let raw_desc = Direct3D12::D3D12_RESOURCE_DESC { + Dimension: Direct3D12::D3D12_RESOURCE_DIMENSION_BUFFER, + Alignment: 0, + Width: size, + Height: 1, + DepthOrArraySize: 1, + MipLevels: 1, + Format: Dxgi::Common::DXGI_FORMAT_UNKNOWN, + SampleDesc: Dxgi::Common::DXGI_SAMPLE_DESC { + Count: 1, + Quality: 0, + }, + Layout: Direct3D12::D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + // TODO: when moving to enhanced barriers use Direct3D12::D3D12_RESOURCE_FLAG_RAYTRACING_ACCELERATION_STRUCTURE + Flags: Direct3D12::D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, + }; + + let (resource, allocation) = + super::suballocation::create_acceleration_structure_resource(self, desc, raw_desc)?; + + if let Some(label) = desc.label { + unsafe { resource.SetName(&windows::core::HSTRING::from(label)) } + .into_device_result("SetName")?; + } + + // for some reason there is no counter for acceleration structures + + Ok(super::AccelerationStructure { + resource, + allocation, + }) } unsafe fn destroy_acceleration_structure( &self, - _acceleration_structure: super::AccelerationStructure, + mut acceleration_structure: super::AccelerationStructure, ) { - // Destroy a D3D12 resource as per-usual. 
- todo!() + if let Some(alloc) = acceleration_structure.allocation.take() { + // Resource should be dropped before suballocation is freed + drop(acceleration_structure); + + super::suballocation::free_acceleration_structure_allocation( + self, + alloc, + &self.mem_allocator, + ); + } } fn get_internal_counters(&self) -> wgt::HalCounters { @@ -1954,7 +2116,21 @@ impl crate::Device for super::Device { }) } - fn tlas_instance_to_bytes(&self, _instance: TlasInstance) -> Vec { - todo!() + fn tlas_instance_to_bytes(&self, instance: TlasInstance) -> Vec { + const MAX_U24: u32 = (1u32 << 24u32) - 1u32; + let temp = Direct3D12::D3D12_RAYTRACING_INSTANCE_DESC { + Transform: instance.transform, + _bitfield1: (instance.custom_index & MAX_U24) | (u32::from(instance.mask) << 24), + _bitfield2: 0, + AccelerationStructure: instance.blas_address, + }; + let temp: *const _ = &temp; + unsafe { + slice::from_raw_parts( + temp.cast::(), + size_of::(), + ) + .to_vec() + } } } diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index d58d79300a..809d53c74d 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -1031,7 +1031,10 @@ pub struct PipelineCache; impl crate::DynPipelineCache for PipelineCache {} #[derive(Debug)] -pub struct AccelerationStructure {} +pub struct AccelerationStructure { + resource: Direct3D12::ID3D12Resource, + allocation: Option, +} impl crate::DynAccelerationStructure for AccelerationStructure {} diff --git a/wgpu-hal/src/dx12/suballocation.rs b/wgpu-hal/src/dx12/suballocation.rs index bdb3e85129..2b0cbf8a47 100644 --- a/wgpu-hal/src/dx12/suballocation.rs +++ b/wgpu-hal/src/dx12/suballocation.rs @@ -151,6 +151,54 @@ pub(crate) fn create_texture_resource( Ok((resource, Some(AllocationWrapper { allocation }))) } +pub(crate) fn create_acceleration_structure_resource( + device: &crate::dx12::Device, + desc: &crate::AccelerationStructureDescriptor, + raw_desc: Direct3D12::D3D12_RESOURCE_DESC, +) -> Result<(Direct3D12::ID3D12Resource, 
Option), crate::DeviceError> { + // Workaround for Intel Xe drivers + if !device.private_caps.suballocation_supported { + return create_committed_acceleration_structure_resource(device, desc, raw_desc) + .map(|resource| (resource, None)); + } + + let location = MemoryLocation::GpuOnly; + + let name = desc.label.unwrap_or("Unlabeled acceleration structure"); + + let mut allocator = device.mem_allocator.lock(); + + let allocation_desc = AllocationCreateDesc::from_d3d12_resource_desc( + allocator.allocator.device(), + &raw_desc, + name, + location, + ); + let allocation = allocator.allocator.allocate(&allocation_desc)?; + let mut resource = None; + + unsafe { + device.raw.CreatePlacedResource( + allocation.heap(), + allocation.offset(), + &raw_desc, + Direct3D12::D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, + None, + &mut resource, + ) + } + .into_device_result("Placed acceleration structure creation")?; + + let resource = resource.ok_or(crate::DeviceError::Unexpected)?; + + device + .counters + .acceleration_structure_memory + .add(allocation.size() as isize); + + Ok((resource, Some(AllocationWrapper { allocation }))) +} + pub(crate) fn free_buffer_allocation( device: &crate::dx12::Device, allocation: AllocationWrapper, @@ -183,6 +231,22 @@ pub(crate) fn free_texture_allocation( }; } +pub(crate) fn free_acceleration_structure_allocation( + device: &crate::dx12::Device, + allocation: AllocationWrapper, + allocator: &Mutex, +) { + device + .counters + .acceleration_structure_memory + .sub(allocation.allocation.size() as isize); + match allocator.lock().allocator.free(allocation.allocation) { + Ok(_) => (), + // TODO: Don't panic here + Err(e) => panic!("Failed to destroy dx12 acceleration structure, {e}"), + }; +} + impl From for crate::DeviceError { fn from(result: gpu_allocator::AllocationError) -> Self { match result { @@ -304,3 +368,40 @@ pub(crate) fn create_committed_texture_resource( resource.ok_or(crate::DeviceError::Unexpected) } + +pub(crate) fn 
create_committed_acceleration_structure_resource( + device: &crate::dx12::Device, + _desc: &crate::AccelerationStructureDescriptor, + raw_desc: Direct3D12::D3D12_RESOURCE_DESC, +) -> Result { + let heap_properties = Direct3D12::D3D12_HEAP_PROPERTIES { + Type: Direct3D12::D3D12_HEAP_TYPE_CUSTOM, + CPUPageProperty: Direct3D12::D3D12_CPU_PAGE_PROPERTY_NOT_AVAILABLE, + MemoryPoolPreference: match device.private_caps.memory_architecture { + crate::dx12::MemoryArchitecture::NonUnified => Direct3D12::D3D12_MEMORY_POOL_L1, + _ => Direct3D12::D3D12_MEMORY_POOL_L0, + }, + CreationNodeMask: 0, + VisibleNodeMask: 0, + }; + + let mut resource = None; + + unsafe { + device.raw.CreateCommittedResource( + &heap_properties, + if device.private_caps.heap_create_not_zeroed { + Direct3D12::D3D12_HEAP_FLAG_CREATE_NOT_ZEROED + } else { + Direct3D12::D3D12_HEAP_FLAG_NONE + }, + &raw_desc, + Direct3D12::D3D12_RESOURCE_STATE_RAYTRACING_ACCELERATION_STRUCTURE, + None, + &mut resource, + ) + } + .into_device_result("Committed acceleration structure creation")?; + + resource.ok_or(crate::DeviceError::Unexpected) +} diff --git a/wgpu-types/src/counters.rs b/wgpu-types/src/counters.rs index 6137a6a2b4..ff38b33c66 100644 --- a/wgpu-types/src/counters.rs +++ b/wgpu-types/src/counters.rs @@ -126,6 +126,8 @@ pub struct HalCounters { pub buffer_memory: InternalCounter, /// Amount of allocated gpu memory attributed to textures, in bytes. pub texture_memory: InternalCounter, + /// Amount of allocated gpu memory attributed to acceleration structures, in bytes. + pub acceleration_structure_memory: InternalCounter, /// Number of gpu memory allocations. 
pub memory_allocations: InternalCounter, } diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 1437437933..a04bbf3f38 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1465,7 +1465,7 @@ impl dispatch::DeviceInterface for CoreDevice { global.device_create_tlas(self.id, &desc.map_label(|l| l.map(Borrowed)), None); if let Some(cause) = error { self.context - .handle_error(&self.error_sink, cause, desc.label, "Device::create_blas"); + .handle_error(&self.error_sink, cause, desc.label, "Device::create_tlas"); } CoreTlas { context: self.context.clone(),