Skip to content

Commit

Permalink
Fix nullable snapshots and a memory leak. (#415)
Browse files Browse the repository at this point in the history
* Add  override skeleton.

* Fix nullable snapshots.

* Fix memory leak when `unbox`ing.

* Add test.
  • Loading branch information
azteca1998 authored Jan 4, 2024
1 parent 9ec1790 commit 21d421b
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 16 deletions.
23 changes: 16 additions & 7 deletions src/libfuncs/box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,27 @@ where
let (inner_ty, inner_layout) =
registry.build_type_with_layout(context, helper, registry, metadata, &info.ty)?;

let op = entry.append_operation(llvm::load(
let value = entry
.append_operation(llvm::load(
context,
entry.argument(0)?.into(),
inner_ty,
location,
LoadStoreOptions::new().align(Some(IntegerAttribute::new(
inner_layout.align() as i64,
IntegerType::new(context, 64).into(),
))),
))
.result(0)?
.into();

entry.append_operation(ReallocBindingsMeta::free(
context,
entry.argument(0)?.into(),
inner_ty,
location,
LoadStoreOptions::new().align(Some(IntegerAttribute::new(
inner_layout.align() as i64,
IntegerType::new(context, 64).into(),
))),
));
entry.append_operation(helper.br(0, &[op.result(0)?.into()], location));

entry.append_operation(helper.br(0, &[value], location));
Ok(())
}

Expand Down
46 changes: 45 additions & 1 deletion src/libfuncs/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ where
#[cfg(test)]
mod test {
use crate::{
utils::test::{jit_struct, load_cairo, run_program_assert_output},
utils::test::{jit_enum, jit_struct, load_cairo, run_program_assert_output},
values::JitValue,
};

Expand Down Expand Up @@ -239,4 +239,48 @@ mod test {
run_program_assert_output(&program, "run_test", &[4u8.into()], 4u8.into());
run_program_assert_output(&program, "run_test", &[0u8.into()], 99u8.into());
}

#[test]
fn match_snapshot_nullable_clone_bug() {
let program = load_cairo! {
use core::{NullableTrait, match_nullable, null, nullable::FromNullableResult};

fn run_test(x: Option<u8>) -> Option<u8> {
let a = match x {
Option::Some(x) => @NullableTrait::new(x),
Option::None(_) => @null::<u8>(),
};
let b = *a;
match match_nullable(b) {
FromNullableResult::Null(_) => Option::None(()),
FromNullableResult::NotNull(x) => Option::Some(x.unbox()),
}
}
};

run_program_assert_output(
&program,
"run_test",
&[jit_enum!(0, 42u8.into())],
jit_enum!(0, 42u8.into()),
);
run_program_assert_output(
&program,
"run_test",
&[jit_enum!(
1,
JitValue::Struct {
fields: Vec::new(),
debug_name: None
}
)],
jit_enum!(
1,
JitValue::Struct {
fields: Vec::new(),
debug_name: None
}
),
);
}
}
167 changes: 159 additions & 8 deletions src/types/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,30 @@
use super::{TypeBuilder, WithSelf};
use crate::{
error::types::{Error, Result},
metadata::MetadataStorage,
error::{
libfuncs,
types::{Error, Result},
},
libfuncs::{LibfuncBuilder, LibfuncHelper},
metadata::{
realloc_bindings::ReallocBindingsMeta, snapshot_clones::SnapshotClonesMeta, MetadataStorage,
},
};
use cairo_lang_sierra::{
extensions::{types::InfoAndTypeConcreteType, GenericLibfunc, GenericType},
program_registry::ProgramRegistry,
};
use melior::{
dialect::llvm,
ir::{Module, Type},
dialect::{
arith::{self, CmpiPredicate},
llvm, scf,
},
ir::{
attribute::{IntegerAttribute, StringAttribute},
operation::OperationBuilder,
r#type::IntegerType,
Block, Location, Module, Region, Type, Value,
},
Context,
};

Expand All @@ -27,14 +41,151 @@ pub fn build<'ctx, TType, TLibfunc>(
context: &'ctx Context,
_module: &Module<'ctx>,
_registry: &ProgramRegistry<TType, TLibfunc>,
_metadata: &mut MetadataStorage,
_info: WithSelf<InfoAndTypeConcreteType>,
metadata: &mut MetadataStorage,
info: WithSelf<InfoAndTypeConcreteType>,
) -> Result<Type<'ctx>>
where
TType: GenericType,
TLibfunc: GenericLibfunc,
TType: 'static + GenericType,
TLibfunc: 'static + GenericLibfunc,
<TType as GenericType>::Concrete: TypeBuilder<TType, TLibfunc, Error = Error>,
<TLibfunc as GenericLibfunc>::Concrete:
LibfuncBuilder<TType, TLibfunc, Error = libfuncs::Error>,
{
metadata
.get_or_insert_with::<SnapshotClonesMeta<TType, TLibfunc>>(SnapshotClonesMeta::default)
.register(
info.self_ty().clone(),
snapshot_take,
InfoAndTypeConcreteType {
info: info.info.clone(),
ty: info.ty.clone(),
},
);

// nullable is represented as a pointer, like a box, used to check if its null (when it can be null).
Ok(llvm::r#type::opaque_pointer(context))
}

#[allow(clippy::too_many_arguments)]
fn snapshot_take<'ctx, 'this, TType, TLibfunc>(
context: &'ctx Context,
registry: &ProgramRegistry<TType, TLibfunc>,
entry: &'this Block<'ctx>,
location: Location<'ctx>,
helper: &LibfuncHelper<'ctx, 'this>,
metadata: &mut MetadataStorage,
info: WithSelf<InfoAndTypeConcreteType>,
src_value: Value<'ctx, 'this>,
) -> libfuncs::Result<Value<'ctx, 'this>>
where
TType: 'static + GenericType,
TLibfunc: 'static + GenericLibfunc,
<TType as GenericType>::Concrete: TypeBuilder<TType, TLibfunc, Error = Error>,
<TLibfunc as GenericLibfunc>::Concrete:
LibfuncBuilder<TType, TLibfunc, Error = libfuncs::Error>,
{
if metadata.get::<ReallocBindingsMeta>().is_none() {
metadata.insert(ReallocBindingsMeta::new(context, helper));
}

// let elem_snapshot_take = metadata
// .get::<SnapshotClonesMeta<TType, TLibfunc>>()
// .and_then(|meta| meta.wrap_invoke(&info.ty));

let elem_layout = registry.get_type(&info.ty)?.layout(registry)?;

let k0 = entry
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, IntegerType::new(context, 64).into()).into(),
location,
))
.result(0)?
.into();
let null_ptr = entry
.append_operation(llvm::nullptr(
llvm::r#type::opaque_pointer(context),
location,
))
.result(0)?
.into();

let ptr_value = entry
.append_operation(
OperationBuilder::new("llvm.ptrtoint", location)
.add_operands(&[src_value])
.add_results(&[IntegerType::new(context, 64).into()])
.build()?,
)
.result(0)?
.into();

let is_null = entry
.append_operation(arith::cmpi(
context,
CmpiPredicate::Eq,
ptr_value,
k0,
location,
))
.result(0)?
.into();
Ok(entry
.append_operation(scf::r#if(
is_null,
&[llvm::r#type::opaque_pointer(context)],
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));

block.append_operation(scf::r#yield(&[null_ptr], location));
region
},
{
let region = Region::new();
let block = region.append_block(Block::new(&[]));

let alloc_len = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(
elem_layout.size() as i64,
IntegerType::new(context, 64).into(),
)
.into(),
location,
))
.result(0)?
.into();

let cloned_ptr = block
.append_operation(ReallocBindingsMeta::realloc(
context, null_ptr, alloc_len, location,
))
.result(0)?
.into();

let is_volatile = block
.append_operation(arith::constant(
context,
IntegerAttribute::new(0, IntegerType::new(context, 1).into()).into(),
location,
))
.result(0)?
.into();
block.append_operation(llvm::call_intrinsic(
context,
StringAttribute::new(context, "llvm.memcpy.inline"),
&[cloned_ptr, src_value, alloc_len, is_volatile],
&[],
location,
));

block.append_operation(scf::r#yield(&[cloned_ptr], location));
region
},
location,
))
.result(0)?
.into())
}

0 comments on commit 21d421b

Please sign in to comment.