From d4d5612645fac495164fee8b53ebc0d2fd46dbf4 Mon Sep 17 00:00:00 2001 From: Frank Bossen Date: Sun, 17 Mar 2024 12:53:16 -0400 Subject: [PATCH] Port C code to Rust --- src/decode.rs | 58 +++++++++++++++++---------------- src/intra_edge.rs | 83 ++++++++++++++--------------------------------- 2 files changed, 54 insertions(+), 87 deletions(-) diff --git a/src/decode.rs b/src/decode.rs index fcd0d75e4..b9d3fafe4 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -3575,16 +3575,16 @@ unsafe fn decode_sb( None => { let tip = intra_edge.tip(sb128, edge_index); assert!(hsz == 1); - decode_b(c, t, f, bl, BS_4x4, bp, tip.split[0])?; + decode_b(c, t, f, bl, BS_4x4, bp, EdgeFlags::EDGE_ALL_TR_AND_BL)?; let tl_filter = t.tl_4x4_filter; t.bx += 1; - decode_b(c, t, f, bl, BS_4x4, bp, tip.split[1])?; + decode_b(c, t, f, bl, BS_4x4, bp, tip.split[0])?; t.bx -= 1; t.by += 1; - decode_b(c, t, f, bl, BS_4x4, bp, tip.split[2])?; + decode_b(c, t, f, bl, BS_4x4, bp, tip.split[1])?; t.bx += 1; t.tl_4x4_filter = tl_filter; - decode_b(c, t, f, bl, BS_4x4, bp, tip.split[3])?; + decode_b(c, t, f, bl, BS_4x4, bp, tip.split[2])?; t.bx -= 1; t.by -= 1; if cfg!(target_arch = "x86_64") && t.frame_thread.pass != 0 { @@ -3612,68 +3612,70 @@ unsafe fn decode_sb( } } PARTITION_T_TOP_SPLIT => { - let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.tts[0])?; + let node = intra_edge.node(sb128, edge_index); + decode_b(c, t, f, bl, b[0], bp, EdgeFlags::EDGE_ALL_TR_AND_BL)?; t.bx += hsz; - decode_b(c, t, f, bl, b[0], bp, branch.tts[1])?; + decode_b(c, t, f, bl, b[0], bp, node.v[1])?; t.bx -= hsz; t.by += hsz; - decode_b(c, t, f, bl, b[1], bp, branch.tts[2])?; + decode_b(c, t, f, bl, b[1], bp, node.h[1])?; t.by -= hsz; } PARTITION_T_BOTTOM_SPLIT => { - let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.tbs[0])?; + let node = intra_edge.node(sb128, edge_index); + decode_b(c, t, f, bl, b[0], bp, node.h[0])?; t.by += hsz; - decode_b(c, t, f, bl, b[1], bp, branch.tbs[1])?; + decode_b(c, t, f, bl, b[1], bp, node.v[0])?; t.bx += hsz; - decode_b(c, t, f, bl, b[1], bp, branch.tbs[2])?; + decode_b(c, t, f, bl, b[1], bp, EdgeFlags::EDGE_NONE)?; t.bx -= hsz; t.by -= hsz; } PARTITION_T_LEFT_SPLIT => { - let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.tls[0])?; + let node = intra_edge.node(sb128, edge_index); + decode_b(c, t, f, bl, b[0], bp, EdgeFlags::EDGE_ALL_TR_AND_BL)?; t.by += hsz; - decode_b(c, t, f, bl, b[0], bp, branch.tls[1])?; + decode_b(c, t, f, bl, b[0], bp, node.h[1])?; t.by -= hsz; t.bx += hsz; - decode_b(c, t, f, bl, b[1], bp, branch.tls[2])?; + decode_b(c, t, f, bl, b[1], bp, node.v[1])?; t.bx -= hsz; } PARTITION_T_RIGHT_SPLIT => { - let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.trs[0])?; + let node = intra_edge.node(sb128, edge_index); + decode_b(c, t, f, bl, b[0], bp, node.v[0])?; t.bx += hsz; - decode_b(c, t, f, bl, b[1], bp, branch.trs[1])?; + decode_b(c, t, f, bl, b[1], bp, node.h[0])?; t.by += hsz; - decode_b(c, t, f, bl, b[1], bp, (*branch).trs[2])?; + decode_b(c, t, f, bl, b[1], bp, EdgeFlags::EDGE_NONE)?; t.by -= hsz; t.bx -= hsz; } PARTITION_H4 => { + let node = intra_edge.node(sb128, edge_index); let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.h4[0])?; + decode_b(c, t, f, bl, b[0], bp, node.h[0])?; t.by += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.h4[1])?; + decode_b(c, t, f, bl, b[0], bp, branch.h4)?; t.by += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.h4[2])?; + decode_b(c, t, f, bl, b[0], bp, EdgeFlags::LEFT_HAS_BOTTOM)?; t.by += hsz >> 1; if t.by < f.bh { - decode_b(c, t, f, bl, b[0], bp, branch.h4[3])?; + decode_b(c, t, f, bl, b[0], bp, node.h[1])?; } t.by -= hsz * 3 >> 1; } PARTITION_V4 => { + let node = intra_edge.node(sb128, edge_index); let branch = intra_edge.branch(sb128, edge_index); - decode_b(c, t, f, bl, b[0], bp, branch.v4[0])?; + decode_b(c, t, f, bl, b[0], bp, node.v[0])?; t.bx += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.v4[1])?; + decode_b(c, t, f, bl, b[0], bp, branch.v4)?; t.bx += hsz >> 1; - decode_b(c, t, f, bl, b[0], bp, branch.v4[2])?; + decode_b(c, t, f, bl, b[0], bp, EdgeFlags::TOP_HAS_RIGHT)?; t.bx += hsz >> 1; if t.bx < f.bw { - decode_b(c, t, f, bl, b[0], bp, branch.v4[3])?; + decode_b(c, t, f, bl, b[0], bp, node.v[1])?; } t.bx -= hsz * 3 >> 1; } diff --git a/src/intra_edge.rs b/src/intra_edge.rs index 442ebe439..aeebcd93b 100644 --- a/src/intra_edge.rs +++ b/src/intra_edge.rs @@ -18,18 +18,23 @@ bitflags! { } impl EdgeFlags { - const LEFT_HAS_BOTTOM: Self = Self::union_all([ + pub const LEFT_HAS_BOTTOM: Self = Self::union_all([ Self::I444_LEFT_HAS_BOTTOM, Self::I422_LEFT_HAS_BOTTOM, Self::I420_LEFT_HAS_BOTTOM, ]); - const TOP_HAS_RIGHT: Self = Self::union_all([ + pub const TOP_HAS_RIGHT: Self = Self::union_all([ Self::I444_TOP_HAS_RIGHT, Self::I422_TOP_HAS_RIGHT, Self::I420_TOP_HAS_RIGHT, ]); + pub const EDGE_ALL_TR_AND_BL: Self = + Self::union_all([Self::LEFT_HAS_BOTTOM, Self::TOP_HAS_RIGHT]); + + pub const EDGE_NONE: Self = Self::empty(); + pub const fn union_all(flags: [Self; N]) -> Self { let mut i = 0; let mut output = Self::empty(); @@ -99,19 +104,15 @@ pub struct EdgeNode { #[repr(C)] pub struct EdgeTip { pub node: EdgeNode, - pub split: [EdgeFlags; B], + pub split: [EdgeFlags; 3], } #[repr(C)] pub struct EdgeBranch { pub node: EdgeNode, - pub tts: [EdgeFlags; 3], - pub tbs: [EdgeFlags; 3], - pub tls: [EdgeFlags; 3], - pub trs: [EdgeFlags; 3], - pub h4: [EdgeFlags; 4], - pub v4: [EdgeFlags; 4], - pub split: [EdgeIndex; B], + pub h4: EdgeFlags, + pub v4: EdgeFlags, + pub split: [EdgeIndex; 4], } impl EdgeTip { @@ -135,7 +136,6 @@ impl EdgeTip { let node = EdgeNode { o, h, v }; let split = [ - EdgeFlags::all(), edge_flags .intersection(EdgeFlags::TOP_HAS_RIGHT) .union(EdgeFlags::I422_LEFT_HAS_BOTTOM), @@ -164,51 +164,20 @@ impl EdgeBranch { ]; let node = EdgeNode { o, h, v }; - let h4 = [ - edge_flags.union(EdgeFlags::LEFT_HAS_BOTTOM), - EdgeFlags::LEFT_HAS_BOTTOM.union( - edge_flags - .intersection(EdgeFlags::I420_TOP_HAS_RIGHT) - .select(matches!(bl, BlockLevel::Bl16x16)), - ), - EdgeFlags::LEFT_HAS_BOTTOM, - edge_flags.intersection(EdgeFlags::LEFT_HAS_BOTTOM), - ]; - - let v4 = [ - edge_flags.union(EdgeFlags::TOP_HAS_RIGHT), - EdgeFlags::TOP_HAS_RIGHT.union( - edge_flags - .intersection(EdgeFlags::union_all([ - EdgeFlags::I420_LEFT_HAS_BOTTOM, - EdgeFlags::I422_LEFT_HAS_BOTTOM, - ])) - .select(matches!(bl, BlockLevel::Bl16x16)), - ), - EdgeFlags::TOP_HAS_RIGHT, - edge_flags.intersection(EdgeFlags::TOP_HAS_RIGHT), - ]; + let h4 = EdgeFlags::LEFT_HAS_BOTTOM.union( + edge_flags + .intersection(EdgeFlags::I420_TOP_HAS_RIGHT) + .select(matches!(bl, BlockLevel::Bl16x16)), + ); - let tls = [ - EdgeFlags::all(), - edge_flags.intersection(EdgeFlags::LEFT_HAS_BOTTOM), - edge_flags.intersection(EdgeFlags::TOP_HAS_RIGHT), - ]; - let trs = [ - edge_flags.union(EdgeFlags::TOP_HAS_RIGHT), - edge_flags.union(EdgeFlags::LEFT_HAS_BOTTOM), - EdgeFlags::empty(), - ]; - let tts = [ - EdgeFlags::all(), - edge_flags.intersection(EdgeFlags::TOP_HAS_RIGHT), - edge_flags.intersection(EdgeFlags::LEFT_HAS_BOTTOM), - ]; - let tbs = [ - edge_flags.union(EdgeFlags::LEFT_HAS_BOTTOM), - edge_flags.union(EdgeFlags::TOP_HAS_RIGHT), - EdgeFlags::empty(), - ]; + let v4 = EdgeFlags::TOP_HAS_RIGHT.union( + edge_flags + .intersection(EdgeFlags::union_all([ + EdgeFlags::I420_LEFT_HAS_BOTTOM, + EdgeFlags::I422_LEFT_HAS_BOTTOM, + ])) + .select(matches!(bl, BlockLevel::Bl16x16)), + ); let split = [EdgeIndex::root(); 4]; @@ -216,10 +185,6 @@ impl EdgeBranch { node, h4, v4, - tls, - trs, - tts, - tbs, split, } }