diff --git a/crates/prql_compiler/src/codegen/pl.rs b/crates/prql_compiler/src/codegen/pl.rs index 88b0e55f740c..bca4c267d15a 100644 --- a/crates/prql_compiler/src/codegen/pl.rs +++ b/crates/prql_compiler/src/codegen/pl.rs @@ -61,7 +61,8 @@ impl WriteSource for TyKind { .write_between("{", "}", opt), Set => Some("set".to_string()), Array(elem) => Some(format!("[{}]", elem.write(opt)?)), - Function(func) => { + Function(None) => Some("func".to_string()), + Function(Some(func)) => { let mut r = String::new(); for t in &func.args { diff --git a/crates/prql_compiler/src/ir/pl/types.rs b/crates/prql_compiler/src/ir/pl/types.rs index c8c18166ede0..f6853f2a0176 100644 --- a/crates/prql_compiler/src/ir/pl/types.rs +++ b/crates/prql_compiler/src/ir/pl/types.rs @@ -27,7 +27,7 @@ pub enum TyKind { Set, /// Type of functions with defined params and return types. - Function(TyFunc), + Function(Option), } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, EnumAsInner)] @@ -146,7 +146,9 @@ impl TyKind { many.iter().any(|(_, any)| any.kind.is_super_type_of(one)) } - (TyKind::Function(sup), TyKind::Function(sub)) => { + (TyKind::Function(None), TyKind::Function(_)) => true, + (TyKind::Function(Some(_)), TyKind::Function(None)) => true, + (TyKind::Function(Some(sup)), TyKind::Function(Some(sub))) => { if is_not_super_type_of(sup.return_ty.as_ref(), sub.return_ty.as_ref()) { return false; } diff --git a/crates/prql_compiler/src/semantic/resolver/mod.rs b/crates/prql_compiler/src/semantic/resolver/mod.rs index 0d6177825701..19be15c1c083 100644 --- a/crates/prql_compiler/src/semantic/resolver/mod.rs +++ b/crates/prql_compiler/src/semantic/resolver/mod.rs @@ -1004,7 +1004,7 @@ fn expr_of_func(func: Func) -> Expr { Expr { ty: Some(Ty { - kind: TyKind::Function(ty), + kind: TyKind::Function(Some(ty)), name: None, }), ..Expr::new(ExprKind::Func(Box::new(func))) @@ -1055,6 +1055,12 @@ fn get_stdlib_decl(name: &str) -> Option { "date" => PrimitiveSet::Date, "time" => PrimitiveSet::Time, "timestamp" => PrimitiveSet::Timestamp, + "func" => { + return Some(ExprKind::Type(Ty { + kind: TyKind::Function(None), + name: None, + })) + } _ => return None, }; Some(ExprKind::Type(Ty { diff --git a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-2.snap b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-2.snap index 5ddc45048a2a..66aaca33d72c 100644 --- a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-2.snap +++ b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-2.snap @@ -1,5 +1,5 @@ --- -source: prql-compiler/src/semantic/resolver.rs +source: crates/prql_compiler/src/semantic/resolver/mod.rs expression: "resolve_lineage(r#\"\n from table_1\n join customers (==customer_no)\n \"#).unwrap()" --- columns: @@ -10,12 +10,12 @@ columns: input_name: customers except: [] inputs: - - id: 180 + - id: 182 name: table_1 table: - default_db - table_1 - - id: 192 + - id: 194 name: customers table: - default_db diff --git a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-3.snap b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-3.snap index a4eadb9bc9d3..a29db587a9eb 100644 --- a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-3.snap +++ b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names-3.snap @@ -1,5 +1,5 @@ --- -source: prql-compiler/src/semantic/resolver.rs +source: crates/prql_compiler/src/semantic/resolver/mod.rs expression: "resolve_lineage(r#\"\n from e = employees\n join salaries (==emp_no)\n group {e.emp_no, e.gender} (\n aggregate {\n emp_salary = average salaries.salary\n }\n )\n \"#).unwrap()" --- columns: @@ -7,26 +7,26 @@ columns: name: - e - emp_no - target_id: 219 + target_id: 221 target_name: ~ - Single: name: - e - gender - target_id: 220 + target_id: 222 target_name: ~ - Single: name: - emp_salary - target_id: 245 + target_id: 247 target_name: ~ inputs: - - id: 180 + - id: 182 name: e table: - default_db - employees - - id: 214 + - id: 216 name: salaries table: - default_db diff --git a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names.snap b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names.snap index 5a0ab79428da..5578c934b079 100644 --- a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names.snap +++ b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__test__frames_and_names.snap @@ -1,5 +1,5 @@ --- -source: prql-compiler/src/semantic/resolver.rs +source: crates/prql_compiler/src/semantic/resolver/mod.rs expression: "resolve_lineage(r#\"\n from orders\n select {customer_no, gross, tax, gross - tax}\n take 20\n \"#).unwrap()" --- columns: @@ -7,26 +7,26 @@ columns: name: - orders - customer_no - target_id: 210 + target_id: 212 target_name: ~ - Single: name: - orders - gross - target_id: 211 + target_id: 213 target_name: ~ - Single: name: - orders - tax - target_id: 212 + target_id: 214 target_name: ~ - Single: name: ~ - target_id: 214 + target_id: 216 target_name: ~ inputs: - - id: 180 + - id: 182 name: orders table: - default_db diff --git a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap index bff1b17ebcde..b3f8729d62a1 100644 --- a/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap +++ b/crates/prql_compiler/src/semantic/resolver/snapshots/prql_compiler__semantic__resolver__transforms__tests__aggregate_positional_arg-2.snap @@ -1,5 +1,5 @@ --- -source: prql-compiler/src/semantic/transforms.rs +source: crates/prql_compiler/src/semantic/resolver/transforms.rs expression: expr --- TransformCall: @@ -21,7 +21,7 @@ TransformCall: input_name: c_invoice except: [] inputs: - - id: 180 + - id: 182 name: c_invoice table: - default_db @@ -185,14 +185,14 @@ lineage: name: - c_invoice - issued_at - target_id: 203 + target_id: 205 target_name: ~ - Single: name: ~ - target_id: 227 + target_id: 229 target_name: ~ inputs: - - id: 180 + - id: 182 name: c_invoice table: - default_db diff --git a/crates/prql_compiler/src/semantic/resolver/type_resolver.rs b/crates/prql_compiler/src/semantic/resolver/type_resolver.rs index dddb54e8a8cf..d437276b21d7 100644 --- a/crates/prql_compiler/src/semantic/resolver/type_resolver.rs +++ b/crates/prql_compiler/src/semantic/resolver/type_resolver.rs @@ -107,14 +107,14 @@ fn coerce_kind_to_set(resolver: &mut Resolver, expr: ExprKind) -> Result { // functions ExprKind::Func(func) => Ty { name: None, - kind: TyKind::Function(TyFunc { + kind: TyKind::Function(Some(TyFunc { args: func .params .into_iter() .map(|p| p.ty.map(|t| t.into_ty().unwrap())) .collect_vec(), return_ty: Box::new(resolver.fold_type_expr(Some(func.body))?), - }), + })), }, _ => { diff --git a/crates/prql_compiler/src/semantic/std.prql b/crates/prql_compiler/src/semantic/std.prql index 31c7b02b1c1c..7006bb515c26 100644 --- a/crates/prql_compiler/src/semantic/std.prql +++ b/crates/prql_compiler/src/semantic/std.prql @@ -45,6 +45,7 @@ type text type date type time type timestamp +type `func` ## Generic array # TODO: an array of anything, not just nulls @@ -133,7 +134,10 @@ let remove = `default_db.bottom` top -> ( filter (tuple_every (tuple_map _is_null b.*)) select t.* ) -let loop = pipeline top -> internal loop +let loop = func + pipeline + top + -> internal loop ## Aggregate functions # These return either a scalar when used within `aggregate`, or a column when used anywhere else. @@ -177,11 +181,11 @@ let as = `noresolve.type` column -> internal std.as let in = pattern value -> internal in ## Tuple functions -let tuple_every = list -> internal tuple_every -let tuple_map = fn list -> internal tuple_map -let tuple_zip = a b -> internal tuple_zip -let _eq = a -> internal _eq -let _is_null = a -> _param.a == null +let tuple_every = func list -> internal tuple_every +let tuple_map = func fn list -> internal tuple_map +let tuple_zip = func a b -> internal tuple_zip +let _eq = func a -> internal _eq +let _is_null = func a -> _param.a == null ## Misc let from_text = input `noresolve.format`:csv -> internal from_text diff --git a/crates/prql_compiler/src/tests/test.rs b/crates/prql_compiler/src/tests/test.rs index bbaa13e810c0..0d96e07705ad 100644 --- a/crates/prql_compiler/src/tests/test.rs +++ b/crates/prql_compiler/src/tests/test.rs @@ -179,17 +179,16 @@ fn test_append() { assert_display_snapshot!(compile(r###" from employees - derive {name, cost = salary} + select {name, cost = salary} take 3 append ( from employees - derive {name, cost = salary + bonuses} + select {name, cost = salary + bonuses} take 10 ) "###).unwrap(), @r###" WITH table_0 AS ( SELECT - *, name, salary + bonuses AS cost FROM @@ -202,7 +201,6 @@ fn test_append() { FROM ( SELECT - *, name, salary AS cost FROM @@ -3042,7 +3040,7 @@ fn test_closures_and_pipelines() { assert_display_snapshot!(compile( r###" let addthree = a b c -> s"{a} || {b} || {c}" - let arg = myarg myfunc -> ( myfunc myarg ) + let arg = myarg myfunc -> ( myfunc myarg ) from y select x = ( diff --git a/crates/prql_parser/src/stmt.rs b/crates/prql_parser/src/stmt.rs index 4308bc032428..28f8f83b8758 100644 --- a/crates/prql_parser/src/stmt.rs +++ b/crates/prql_parser/src/stmt.rs @@ -146,7 +146,9 @@ pub fn type_expr() -> impl Parser { let ident = ident().map(ExprKind::Ident); - let term = literal.or(ident).map_with_span(into_expr); + let func = keyword("func").to(ExprKind::Ident(Ident::from_path(vec!["std", "func"]))); + + let term = literal.or(ident).or(func).map_with_span(into_expr); binary_op_parser(term, operator_or()) .delimited_by(ctrl('<'), ctrl('>'))