Skip to content

Commit

Permalink
Fix HLSL single scalar loads (#7104)
Browse files Browse the repository at this point in the history
  • Loading branch information
Vecvec authored Feb 12, 2025
1 parent 0dd6a1c commit 5af9e30
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).

#### Dx12

- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993)
- Fix HLSL storage format generation. By @Vecvec in [#6993](https://github.com/gfx-rs/wgpu/pull/6993) and [#7104](https://github.com/gfx-rs/wgpu/pull/7104)
- Fix 3D storage texture bindings. By @SparkyPotato in [#7071](https://github.com/gfx-rs/wgpu/pull/7071)

#### WebGPU
Expand Down
106 changes: 101 additions & 5 deletions naga/src/back/hlsl/help.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use super::{
writer::{EXTRACT_BITS_FUNCTION, INSERT_BITS_FUNCTION},
BackendResult,
};
use crate::{arena::Handle, proc::NameKey};
use crate::{arena::Handle, proc::NameKey, ScalarKind};
use std::fmt::Write;

#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)]
Expand Down Expand Up @@ -128,6 +128,8 @@ impl From<crate::ImageQuery> for ImageQuery {
}
}

/// Name prefix of the generated HLSL helper that widens a single-scalar storage
/// texture load into a four-component vector.
///
/// The HLSL name of the scalar type is appended to this prefix to form the full
/// function name (for example `LoadedStorageValueFromfloat`).
pub(super) const IMAGE_STORAGE_LOAD_SCALAR_WRAPPER: &str = "LoadedStorageValueFrom";

impl<W: Write> super::Writer<'_, W> {
pub(super) fn write_image_type(
&mut self,
Expand Down Expand Up @@ -513,6 +515,60 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

/// Emits the HLSL helper that widens the result of a single-component storage
/// texture load into a four-component vector.
///
/// The generated function is named [`IMAGE_STORAGE_LOAD_SCALAR_WRAPPER`] followed
/// by the HLSL name of `scalar_type`. It returns a vector holding the loaded
/// scalar in the first component, zero in the second and third, and one in the
/// fourth.
///
/// # Panics
///
/// Panics when `scalar_type` is not a 32-bit signed integer, a 32-bit or 64-bit
/// unsigned integer, or a 32-bit float.
fn write_loaded_scalar_to_storage_loaded_value(
    &mut self,
    scalar_type: crate::Scalar,
) -> BackendResult {
    const ARG: &str = "arg";
    const RET: &str = "ret";

    // Pick the literal spellings for 0 and 1 that match the scalar's HLSL type.
    let (zero, one) = match scalar_type.kind {
        ScalarKind::Sint => {
            assert_eq!(
                scalar_type.width, 4,
                "Scalar {scalar_type:?} is not a result from any storage format"
            );
            ("0", "1")
        }
        ScalarKind::Uint => match scalar_type.width {
            4 => ("0u", "1u"),
            8 => ("0uL", "1uL"),
            _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
        },
        ScalarKind::Float => {
            assert_eq!(
                scalar_type.width, 4,
                "Scalar {scalar_type:?} is not a result from any storage format"
            );
            ("0.0", "1.0")
        }
        _ => unreachable!("Scalar {scalar_type:?} is not a result from any storage format"),
    };

    let ty = scalar_type.to_hlsl_str()?;
    // The trailing backslashes join the pieces, so the helper is emitted on one line.
    writeln!(
        self.out,
        "{ty}4 {IMAGE_STORAGE_LOAD_SCALAR_WRAPPER}{ty}({ty} {ARG}) {{\
         {ty}4 {RET} = {ty}4({ARG}, {zero}, {zero}, {one});\
         return {RET};\
         }}"
    )?;

    Ok(())
}

pub(super) fn write_wrapped_struct_matrix_get_function_name(
&mut self,
access: WrappedStructMatrixAccess,
Expand Down Expand Up @@ -848,11 +904,12 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

/// Helper function that writes compose wrapped functions
pub(super) fn write_wrapped_compose_functions(
/// Helper function that writes wrapped functions for expressions in a function
pub(super) fn write_wrapped_expression_functions(
&mut self,
module: &crate::Module,
expressions: &crate::Arena<crate::Expression>,
context: Option<&FunctionCtx>,
) -> BackendResult {
for (handle, _) in expressions.iter() {
match expressions[handle] {
Expand All @@ -867,6 +924,23 @@ impl<W: Write> super::Writer<'_, W> {
_ => {}
};
}
crate::Expression::ImageLoad { image, .. } => {
// This can only happen in a function as this is not a valid const expression
match *context.as_ref().unwrap().resolve_type(image, &module.types) {
crate::TypeInner::Image {
class: crate::ImageClass::Storage { format, .. },
..
} => {
if format.single_component() {
let scalar: crate::Scalar = format.into();
if self.wrapped.image_load_scalars.insert(scalar) {
self.write_loaded_scalar_to_storage_loaded_value(scalar)?;
}
}
}
_ => {}
}
}
crate::Expression::RayQueryGetIntersection { committed, .. } => {
if committed {
if !self.written_committed_intersection {
Expand All @@ -884,7 +958,7 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}

// TODO: we could merge this with iteration in write_wrapped_compose_functions...
// TODO: we could merge this with iteration in write_wrapped_expression_functions...
//
/// Helper function that writes zero value wrapped functions
pub(super) fn write_wrapped_zero_value_functions(
Expand Down Expand Up @@ -1046,7 +1120,7 @@ impl<W: Write> super::Writer<'_, W> {
func_ctx: &FunctionCtx,
) -> BackendResult {
self.write_wrapped_math_functions(module, func_ctx)?;
self.write_wrapped_compose_functions(module, func_ctx.expressions)?;
self.write_wrapped_expression_functions(module, func_ctx.expressions, Some(func_ctx))?;
self.write_wrapped_zero_value_functions(module, func_ctx.expressions)?;

for (handle, _) in func_ctx.expressions.iter() {
Expand Down Expand Up @@ -1476,3 +1550,25 @@ impl<W: Write> super::Writer<'_, W> {
Ok(())
}
}

impl crate::StorageFormat {
    /// Reports whether this storage format has exactly one component.
    ///
    /// Returns `true` for the `R*` formats listed below and `false` for every
    /// other format.
    pub(super) const fn single_component(&self) -> bool {
        matches!(
            *self,
            // 8-bit channel
            Self::R8Unorm
                | Self::R8Snorm
                | Self::R8Uint
                | Self::R8Sint
                // 16-bit channel
                | Self::R16Unorm
                | Self::R16Snorm
                | Self::R16Uint
                | Self::R16Sint
                | Self::R16Float
                // 32-bit channel
                | Self::R32Uint
                | Self::R32Sint
                | Self::R32Float
                // 64-bit channel
                | Self::R64Uint
        )
    }
}
5 changes: 4 additions & 1 deletion naga/src/back/hlsl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -908,4 +908,7 @@ pub const TYPES: &[&str] = &{
res
};

pub const RESERVED_PREFIXES: &[&str] = &["__dynamic_buffer_offsets"];
// Identifier prefixes treated as reserved by the HLSL backend. Besides the
// dynamic buffer offsets prefix, this includes the prefix of the generated
// single-scalar storage load wrapper functions.
// NOTE(review): presumably consulted when naming user identifiers so they
// cannot collide with backend-generated names — confirm against the namer.
pub const RESERVED_PREFIXES: &[&str] = &[
"__dynamic_buffer_offsets",
super::help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
];
1 change: 1 addition & 0 deletions naga/src/back/hlsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ struct Wrapped {
zero_values: crate::FastHashSet<help::WrappedZeroValue>,
array_lengths: crate::FastHashSet<help::WrappedArrayLength>,
image_queries: crate::FastHashSet<help::WrappedImageQuery>,
image_load_scalars: crate::FastHashSet<crate::Scalar>,
constructors: crate::FastHashSet<help::WrappedConstructor>,
struct_matrix_access: crate::FastHashSet<help::WrappedStructMatrixAccess>,
mat_cx2s: crate::FastHashSet<help::WrappedMatCx2>,
Expand Down
27 changes: 26 additions & 1 deletion naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::{
help,
help::{
WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess,
WrappedZeroValue,
Expand Down Expand Up @@ -341,7 +342,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

self.write_special_functions(module)?;

self.write_wrapped_compose_functions(module, &module.global_expressions)?;
self.write_wrapped_expression_functions(module, &module.global_expressions, None)?;
self.write_wrapped_zero_value_functions(module, &module.global_expressions)?;

// Write all named constants
Expand Down Expand Up @@ -3152,6 +3153,26 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
sample,
level,
} => {
let mut wrapping_type = None;
match *func_ctx.resolve_type(image, &module.types) {
TypeInner::Image {
class: crate::ImageClass::Storage { format, .. },
..
} => {
if format.single_component() {
wrapping_type = Some(Scalar::from(format));
}
}
_ => {}
}
if let Some(scalar) = wrapping_type {
write!(
self.out,
"{}{}(",
help::IMAGE_STORAGE_LOAD_SCALAR_WRAPPER,
scalar.to_hlsl_str()?
)?;
}
// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/dx-graphics-hlsl-to-load
self.write_expr(module, image, func_ctx)?;
write!(self.out, ".Load(")?;
Expand All @@ -3173,6 +3194,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// close bracket for Load function
write!(self.out, ")")?;

if wrapping_type.is_some() {
write!(self.out, ")")?;
}

// return x component if return type is scalar
if let TypeInner::Scalar(_) = *func_ctx.resolve_type(expr, &module.types) {
write!(self.out, ".x")?;
Expand Down
3 changes: 2 additions & 1 deletion naga/tests/out/hlsl/storage-textures.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ RWTexture2D<float> s_r_w : register(u0, space1);
RWTexture2D<float4> s_rg_w : register(u1, space1);
RWTexture2D<float4> s_rgba_w : register(u2, space1);

float4 LoadedStorageValueFromfloat(float arg) {float4 ret = float4(arg, 0.0, 0.0, 1.0);return ret;}
[numthreads(1, 1, 1)]
void csLoad()
{
float4 phony = s_r_r.Load((0u).xx);
float4 phony = LoadedStorageValueFromfloat(s_r_r.Load((0u).xx));
float4 phony_1 = s_rg_r.Load((0u).xx);
float4 phony_2 = s_rgba_r.Load((0u).xx);
return;
Expand Down
128 changes: 125 additions & 3 deletions tests/tests/texture_binding/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::time::Duration;
use wgpu::wgt::BufferDescriptor;
use wgpu::{
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, ComputePassDescriptor,
ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, TextureDescriptor,
TextureDimension, TextureFormat, TextureUsages,
include_wgsl, BindGroupDescriptor, BindGroupEntry, BindingResource, BufferUsages,
ComputePassDescriptor, ComputePipelineDescriptor, DownlevelFlags, Extent3d, Features, Maintain,
MapMode, Origin3d, TexelCopyBufferInfo, TexelCopyBufferLayout, TexelCopyTextureInfo,
TextureAspect, TextureDescriptor, TextureDimension, TextureFormat, TextureUsages,
};
use wgpu_macros::gpu_test;
use wgpu_test::{GpuTestConfiguration, TestParameters, TestingContext};
Expand Down Expand Up @@ -62,3 +65,122 @@ fn texture_binding(ctx: TestingContext) {
}
ctx.queue.submit([encoder.finish()]);
}

// Test configuration that runs `single_scalar_load` synchronously.
// It requests the `WEBGPU_TEXTURE_FORMAT_SUPPORT` downlevel flag and the
// `TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES` feature.
// NOTE(review): presumably adapters lacking these are skipped rather than
// failed — confirm against `TestParameters`.
#[gpu_test]
static SINGLE_SCALAR_LOAD: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.downlevel_flags(DownlevelFlags::WEBGPU_TEXTURE_FORMAT_SUPPORT)
.features(Features::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES),
)
.run_sync(single_scalar_load);

fn single_scalar_load(ctx: TestingContext) {
let texture_read = ctx.device.create_texture(&TextureDescriptor {
label: None,
size: Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: TextureDimension::D2,
format: TextureFormat::R32Float,
usage: TextureUsages::STORAGE_BINDING,
view_formats: &[],
});
let texture_write = ctx.device.create_texture(&TextureDescriptor {
label: None,
size: Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
mip_level_count: 1,
sample_count: 1,
dimension: TextureDimension::D2,
format: TextureFormat::Rgba32Float,
usage: TextureUsages::STORAGE_BINDING | TextureUsages::COPY_SRC,
view_formats: &[],
});
let buffer = ctx.device.create_buffer(&BufferDescriptor {
label: None,
size: size_of::<[f32; 4]>() as wgpu::BufferAddress,
usage: BufferUsages::MAP_READ | BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let shader = ctx
.device
.create_shader_module(include_wgsl!("single_scalar.wgsl"));
let pipeline = ctx
.device
.create_compute_pipeline(&ComputePipelineDescriptor {
label: None,
layout: None,
module: &shader,
entry_point: None,
compilation_options: Default::default(),
cache: None,
});
let bind = ctx.device.create_bind_group(&BindGroupDescriptor {
label: None,
layout: &pipeline.get_bind_group_layout(0),
entries: &[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(
&texture_write.create_view(&Default::default()),
),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::TextureView(
&texture_read.create_view(&Default::default()),
),
},
],
});

let mut encoder = ctx.device.create_command_encoder(&Default::default());
{
let mut pass = encoder.begin_compute_pass(&ComputePassDescriptor::default());
pass.set_pipeline(&pipeline);
pass.set_bind_group(0, &bind, &[]);
pass.dispatch_workgroups(1, 1, 1);
}
encoder.copy_texture_to_buffer(
TexelCopyTextureInfo {
texture: &texture_write,
mip_level: 0,
origin: Origin3d::ZERO,
aspect: TextureAspect::All,
},
TexelCopyBufferInfo {
buffer: &buffer,
layout: TexelCopyBufferLayout {
offset: 0,
bytes_per_row: None,
rows_per_image: None,
},
},
Extent3d {
width: 1,
height: 1,
depth_or_array_layers: 1,
},
);
ctx.queue.submit([encoder.finish()]);
let (send, recv) = std::sync::mpsc::channel();
buffer.slice(..).map_async(MapMode::Read, move |res| {
res.unwrap();
send.send(()).expect("Thread should wait for receive");
});
// Poll to run map.
ctx.device.poll(Maintain::Wait);
recv.recv_timeout(Duration::from_secs(10))
.expect("mapping should not take this long");
let val = *bytemuck::from_bytes::<[f32; 4]>(&buffer.slice(..).get_mapped_range());
assert_eq!(val, [0.0, 0.0, 0.0, 1.0]);
}
Loading

0 comments on commit 5af9e30

Please sign in to comment.