Skip to content

Commit

Permalink
feat: func type (#3058)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Jul 25, 2023
1 parent 219d3ea commit 4863904
Show file tree
Hide file tree
Showing 11 changed files with 51 additions and 38 deletions.
3 changes: 2 additions & 1 deletion crates/prql_compiler/src/codegen/pl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ impl WriteSource for TyKind {
.write_between("{", "}", opt),
Set => Some("set".to_string()),
Array(elem) => Some(format!("[{}]", elem.write(opt)?)),
Function(func) => {
Function(None) => Some("func".to_string()),
Function(Some(func)) => {
let mut r = String::new();

for t in &func.args {
Expand Down
6 changes: 4 additions & 2 deletions crates/prql_compiler/src/ir/pl/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub enum TyKind {
Set,

/// Type of functions with defined params and return types.
Function(TyFunc),
Function(Option<TyFunc>),
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, EnumAsInner)]
Expand Down Expand Up @@ -146,7 +146,9 @@ impl TyKind {
many.iter().any(|(_, any)| any.kind.is_super_type_of(one))
}

(TyKind::Function(sup), TyKind::Function(sub)) => {
(TyKind::Function(None), TyKind::Function(_)) => true,
(TyKind::Function(Some(_)), TyKind::Function(None)) => true,
(TyKind::Function(Some(sup)), TyKind::Function(Some(sub))) => {
if is_not_super_type_of(sup.return_ty.as_ref(), sub.return_ty.as_ref()) {
return false;
}
Expand Down
8 changes: 7 additions & 1 deletion crates/prql_compiler/src/semantic/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ fn expr_of_func(func: Func) -> Expr {

Expr {
ty: Some(Ty {
kind: TyKind::Function(ty),
kind: TyKind::Function(Some(ty)),
name: None,
}),
..Expr::new(ExprKind::Func(Box::new(func)))
Expand Down Expand Up @@ -1055,6 +1055,12 @@ fn get_stdlib_decl(name: &str) -> Option<ExprKind> {
"date" => PrimitiveSet::Date,
"time" => PrimitiveSet::Time,
"timestamp" => PrimitiveSet::Timestamp,
"func" => {
return Some(ExprKind::Type(Ty {
kind: TyKind::Function(None),
name: None,
}))
}
_ => return None,
};
Some(ExprKind::Type(Ty {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
source: prql-compiler/src/semantic/resolver.rs
source: crates/prql_compiler/src/semantic/resolver/mod.rs
expression: "resolve_lineage(r#\"\n from table_1\n join customers (==customer_no)\n \"#).unwrap()"
---
columns:
Expand All @@ -10,12 +10,12 @@ columns:
input_name: customers
except: []
inputs:
- id: 180
- id: 182
name: table_1
table:
- default_db
- table_1
- id: 192
- id: 194
name: customers
table:
- default_db
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
---
source: prql-compiler/src/semantic/resolver.rs
source: crates/prql_compiler/src/semantic/resolver/mod.rs
expression: "resolve_lineage(r#\"\n from e = employees\n join salaries (==emp_no)\n group {e.emp_no, e.gender} (\n aggregate {\n emp_salary = average salaries.salary\n }\n )\n \"#).unwrap()"
---
columns:
- Single:
name:
- e
- emp_no
target_id: 219
target_id: 221
target_name: ~
- Single:
name:
- e
- gender
target_id: 220
target_id: 222
target_name: ~
- Single:
name:
- emp_salary
target_id: 245
target_id: 247
target_name: ~
inputs:
- id: 180
- id: 182
name: e
table:
- default_db
- employees
- id: 214
- id: 216
name: salaries
table:
- default_db
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
---
source: prql-compiler/src/semantic/resolver.rs
source: crates/prql_compiler/src/semantic/resolver/mod.rs
expression: "resolve_lineage(r#\"\n from orders\n select {customer_no, gross, tax, gross - tax}\n take 20\n \"#).unwrap()"
---
columns:
- Single:
name:
- orders
- customer_no
target_id: 210
target_id: 212
target_name: ~
- Single:
name:
- orders
- gross
target_id: 211
target_id: 213
target_name: ~
- Single:
name:
- orders
- tax
target_id: 212
target_id: 214
target_name: ~
- Single:
name: ~
target_id: 214
target_id: 216
target_name: ~
inputs:
- id: 180
- id: 182
name: orders
table:
- default_db
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
source: prql-compiler/src/semantic/transforms.rs
source: crates/prql_compiler/src/semantic/resolver/transforms.rs
expression: expr
---
TransformCall:
Expand All @@ -21,7 +21,7 @@ TransformCall:
input_name: c_invoice
except: []
inputs:
- id: 180
- id: 182
name: c_invoice
table:
- default_db
Expand Down Expand Up @@ -185,14 +185,14 @@ lineage:
name:
- c_invoice
- issued_at
target_id: 203
target_id: 205
target_name: ~
- Single:
name: ~
target_id: 227
target_id: 229
target_name: ~
inputs:
- id: 180
- id: 182
name: c_invoice
table:
- default_db
Expand Down
4 changes: 2 additions & 2 deletions crates/prql_compiler/src/semantic/resolver/type_resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ fn coerce_kind_to_set(resolver: &mut Resolver, expr: ExprKind) -> Result<Ty> {
// functions
ExprKind::Func(func) => Ty {
name: None,
kind: TyKind::Function(TyFunc {
kind: TyKind::Function(Some(TyFunc {
args: func
.params
.into_iter()
.map(|p| p.ty.map(|t| t.into_ty().unwrap()))
.collect_vec(),
return_ty: Box::new(resolver.fold_type_expr(Some(func.body))?),
}),
})),
},

_ => {
Expand Down
16 changes: 10 additions & 6 deletions crates/prql_compiler/src/semantic/std.prql
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ type text
type date
type time
type timestamp
type `func`

## Generic array
# TODO: an array of anything, not just nulls
Expand Down Expand Up @@ -133,7 +134,10 @@ let remove = `default_db.bottom`<relation> top<relation> -> <relation> (
filter (tuple_every (tuple_map _is_null b.*))
select t.*
)
let loop = pipeline top<relation> -> <relation> internal loop
let loop = func
pipeline <transform>
top <relation>
-> <relation> internal loop

## Aggregate functions
# These return either a scalar when used within `aggregate`, or a column when used anywhere else.
Expand Down Expand Up @@ -177,11 +181,11 @@ let as = `noresolve.type` column -> <scalar> internal std.as
let in = pattern value -> <bool> internal in

## Tuple functions
let tuple_every = list -> <bool> internal tuple_every
let tuple_map = fn list -> internal tuple_map
let tuple_zip = a b -> internal tuple_zip
let _eq = a -> internal _eq
let _is_null = a -> _param.a == null
let tuple_every = func list -> <bool> internal tuple_every
let tuple_map = func fn <func> list -> internal tuple_map
let tuple_zip = func a b -> internal tuple_zip
let _eq = func a -> internal _eq
let _is_null = func a -> _param.a == null

## Misc
let from_text = input<text> `noresolve.format`:csv -> <relation> internal from_text
Expand Down
8 changes: 3 additions & 5 deletions crates/prql_compiler/src/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,16 @@ fn test_append() {

assert_display_snapshot!(compile(r###"
from employees
derive {name, cost = salary}
select {name, cost = salary}
take 3
append (
from employees
derive {name, cost = salary + bonuses}
select {name, cost = salary + bonuses}
take 10
)
"###).unwrap(), @r###"
WITH table_0 AS (
SELECT
*,
name,
salary + bonuses AS cost
FROM
Expand All @@ -202,7 +201,6 @@ fn test_append() {
FROM
(
SELECT
*,
name,
salary AS cost
FROM
Expand Down Expand Up @@ -3042,7 +3040,7 @@ fn test_closures_and_pipelines() {
assert_display_snapshot!(compile(
r###"
let addthree = a b c -> s"{a} || {b} || {c}"
let arg = myarg myfunc -> ( myfunc myarg )
let arg = myarg myfunc <func> -> ( myfunc myarg )
from y
select x = (
Expand Down
4 changes: 3 additions & 1 deletion crates/prql_parser/src/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ pub fn type_expr() -> impl Parser<Token, Expr, Error = PError> {

let ident = ident().map(ExprKind::Ident);

let term = literal.or(ident).map_with_span(into_expr);
let func = keyword("func").to(ExprKind::Ident(Ident::from_path(vec!["std", "func"])));

let term = literal.or(ident).or(func).map_with_span(into_expr);

binary_op_parser(term, operator_or())
.delimited_by(ctrl('<'), ctrl('>'))
Expand Down

0 comments on commit 4863904

Please sign in to comment.