Skip to content

Commit

Permalink
refactor: a few tweaks to semantic module (#3057)
Browse files Browse the repository at this point in the history
  • Loading branch information
aljazerzen authored Jul 25, 2023
1 parent b612f09 commit 219d3ea
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 120 deletions.
16 changes: 14 additions & 2 deletions crates/prql_compiler/src/ir/pl/extra/expr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use enum_as_inner::EnumAsInner;
use prql_ast::expr::Literal;
use prql_ast::expr::{Ident, Literal};
use serde::{Deserialize, Serialize};

use crate::generic::WindowKind;
Expand Down Expand Up @@ -136,6 +136,18 @@ impl Expr {

impl From<Literal> for ExprKind {
fn from(value: Literal) -> Self {
Self::Literal(value)
ExprKind::Literal(value)
}
}

impl From<Ident> for ExprKind {
fn from(value: Ident) -> Self {
ExprKind::Ident(value)
}
}

impl From<Func> for ExprKind {
fn from(value: Func) -> Self {
ExprKind::Func(Box::new(value))
}
}
31 changes: 25 additions & 6 deletions crates/prql_compiler/src/semantic/ast_expand.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use prql_ast::expr::{BinaryExpr, Expr, ExprKind, Ident};
use prql_ast::expr::{BinOp, BinaryExpr, Expr, ExprKind, Ident};
use prql_ast::stmt::{Annotation, Stmt, StmtKind, VarDefKind};

use crate::ir::pl;
Expand Down Expand Up @@ -233,10 +233,17 @@ fn restrict_expr_kind(value: pl::ExprKind) -> ExprKind {
),
pl::ExprKind::Param(v) => ExprKind::Param(v),
pl::ExprKind::Internal(v) => ExprKind::Internal(v),
pl::ExprKind::All { .. }
| pl::ExprKind::TransformCall(_)
| pl::ExprKind::RqOperator { .. }
| pl::ExprKind::Type(_) => ExprKind::Ident(Ident::from_name("?")),

// TODO: these are not correct, they are producing invalid PRQL
pl::ExprKind::All { within, .. } => ExprKind::Ident(within),
pl::ExprKind::Type(ty) => ExprKind::Ident(Ident::from_name(format!("<{}>", ty))),
pl::ExprKind::TransformCall(tc) => ExprKind::Ident(Ident::from_name(format!(
"({} ...)",
tc.kind.as_ref().as_ref()
))),
pl::ExprKind::RqOperator { name, .. } => {
ExprKind::Ident(Ident::from_name(format!("({} ...)", name)))
}
}
}

Expand Down Expand Up @@ -265,7 +272,19 @@ fn restrict_ty(value: pl::Ty) -> prql_ast::expr::Expr {
ExprKind::Ident(Ident::from_path(vec!["std".to_string(), prim.to_string()]))
}
pl::TyKind::Singleton(lit) => ExprKind::Literal(lit),
pl::TyKind::Union(_) => todo!(),
pl::TyKind::Union(mut variants) => {
variants.reverse();
let mut res = restrict_ty(variants.pop().unwrap().1);
while let Some((_, ty)) = variants.pop() {
let ty = restrict_ty(ty);
res = Expr::new(ExprKind::Binary(BinaryExpr {
left: Box::new(res),
op: BinOp::Or,
right: Box::new(ty),
}));
}
return res;
}
pl::TyKind::Tuple(fields) => ExprKind::Tuple(
fields
.into_iter()
Expand Down
7 changes: 4 additions & 3 deletions crates/prql_compiler/src/semantic/resolver/context_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ impl Context {
// if ident.name != "*" {
// return Err("Unsupported feature: advanced wildcard column matching".to_string());
// }
return self
.resolve_ident_wildcard(ident)
.map_err(Error::new_simple);
return self.resolve_ident_wildcard(ident).map_err(|e| {
log::debug!("{:#?}", self.root_mod);
Error::new_simple(e)
});
}

// base case: direct lookup
Expand Down
212 changes: 103 additions & 109 deletions crates/prql_compiler/src/semantic/resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,15 +367,15 @@ impl PlFold for Resolver {

// fold function
let func = self.apply_args_to_closure(func, args, named_args)?;
self.fold_function(func, node.span)?
self.fold_function(func, span)?
}

ExprKind::Pipeline(pipeline) => {
self.default_namespace = None;
self.resolve_pipeline(pipeline)?
}

ExprKind::Func(closure) => self.fold_function(*closure, node.span)?,
ExprKind::Func(closure) => self.fold_function(*closure, span)?,

ExprKind::Unary(UnaryExpr {
op: UnOp::EqSelf,
Expand Down Expand Up @@ -629,9 +629,8 @@ impl Resolver {
// );
}

fn fold_function(&mut self, closure: Func, span: Option<Span>) -> Result<Expr, anyhow::Error> {
fn fold_function(&mut self, closure: Func, span: Option<Span>) -> Result<Expr> {
let closure = self.fold_function_types(closure)?;
let args_len = closure.args.len();

log::debug!(
"func {} {}/{} params",
Expand All @@ -650,126 +649,101 @@ impl Resolver {
}

let enough_args = closure.args.len() == closure.params.len();
if !enough_args {
// not enough arguments: don't fold
log::debug!("returning as closure");
return Ok(expr_of_func(closure));
}

let mut r = if enough_args {
// make sure named args are pushed into params
let closure = if !closure.named_params.is_empty() {
self.apply_args_to_closure(closure, [].into(), [].into())?
} else {
closure
};
// make sure named args are pushed into params
let closure = if !closure.named_params.is_empty() {
self.apply_args_to_closure(closure, [].into(), [].into())?
} else {
closure
};

// push the env
let closure_env = Module::from_exprs(closure.env);
self.context.root_mod.stack_push(NS_PARAM, closure_env);
let closure = Func {
env: HashMap::new(),
..closure
};
// push the env
let closure_env = Module::from_exprs(closure.env);
self.context.root_mod.stack_push(NS_PARAM, closure_env);
let closure = Func {
env: HashMap::new(),
..closure
};

if log::log_enabled!(log::Level::Debug) {
let name = closure
.name_hint
.clone()
.unwrap_or_else(|| Ident::from_name("<unnamed>"));
log::debug!("resolving args of function {}", name);
}
let closure = self.resolve_function_args(closure)?;

let needs_window = (closure.params.last())
.and_then(|p| p.ty.as_ref())
.map(|t| t.as_ty().unwrap().is_sub_type_of_array())
.unwrap_or_default();

// evaluate
let res = if let ExprKind::Internal(operator_name) = &closure.body.kind {
// special case: functions that have internal body

if operator_name.starts_with("std.") {
Expr {
ty: closure.return_ty.map(|t| t.into_ty().unwrap()),
needs_window,
..Expr::new(ExprKind::RqOperator {
name: operator_name.clone(),
args: closure.args,
})
}
} else {
let expr = transforms::cast_transform(self, closure)?;
self.fold_expr(expr)?
}
} else {
// base case: materialize
log::debug!("stack_push for {}", closure.as_debug_name());
if log::log_enabled!(log::Level::Debug) {
let name = closure
.name_hint
.clone()
.unwrap_or_else(|| Ident::from_name("<unnamed>"));
log::debug!("resolving args of function {}", name);
}
let closure = self.resolve_function_args(closure)?;

let (func_env, body) = env_of_closure(closure);
let needs_window = (closure.params.last())
.and_then(|p| p.ty.as_ref())
.map(|t| t.as_ty().unwrap().is_sub_type_of_array())
.unwrap_or_default();

self.context.root_mod.stack_push(NS_PARAM, func_env);
// evaluate
let res = if let ExprKind::Internal(operator_name) = &closure.body.kind {
// special case: functions that have internal body

// fold again, to resolve inner variables & functions
let body = self.fold_expr(body)?;
if operator_name.starts_with("std.") {
Expr {
ty: closure.return_ty.map(|t| t.into_ty().unwrap()),
needs_window,
..Expr::new(ExprKind::RqOperator {
name: operator_name.clone(),
args: closure.args,
})
}
} else {
let expr = transforms::cast_transform(self, closure)?;
self.fold_expr(expr)?
}
} else {
// base case: materialize
log::debug!("stack_push for {}", closure.as_debug_name());

// remove param decls
log::debug!("stack_pop: {:?}", body.id);
let func_env = self.context.root_mod.stack_pop(NS_PARAM).unwrap();
let (func_env, body) = env_of_closure(closure);

if let ExprKind::Func(mut inner_closure) = body.kind {
// body couldn't been resolved - construct a closure to be evaluated later
self.context.root_mod.stack_push(NS_PARAM, func_env);

inner_closure.env = func_env.into_exprs();
// fold again, to resolve inner variables & functions
let body = self.fold_expr(body)?;

let (got, missing) = inner_closure.params.split_at(inner_closure.args.len());
let missing = missing.to_vec();
inner_closure.params = got.to_vec();
// remove param decls
log::debug!("stack_pop: {:?}", body.id);
let func_env = self.context.root_mod.stack_pop(NS_PARAM).unwrap();

Expr::new(ExprKind::Func(Box::new(Func {
name_hint: None,
args: vec![],
params: missing,
named_params: vec![],
body: Box::new(Expr::new(ExprKind::Func(inner_closure))),
return_ty: None,
env: HashMap::new(),
})))
} else {
// resolved, return result
body
}
};
if let ExprKind::Func(mut inner_closure) = body.kind {
// body couldn't been resolved - construct a closure to be evaluated later

// pop the env
self.context.root_mod.stack_pop(NS_PARAM).unwrap();
inner_closure.env = func_env.into_exprs();

res
} else {
// not enough arguments: don't fold
log::debug!("returning as closure");
let (got, missing) = inner_closure.params.split_at(inner_closure.args.len());
let missing = missing.to_vec();
inner_closure.params = got.to_vec();

let ty = TyFunc {
args: closure
.params
.iter()
.skip(args_len)
.map(|a| a.ty.as_ref().map(|x| x.as_ty().cloned().unwrap()))
.collect(),
return_ty: Box::new(
closure
.return_ty
.as_ref()
.map(|x| x.as_ty().cloned().unwrap()),
),
};
Expr::new(ExprKind::Func(Box::new(Func {
name_hint: None,
args: vec![],
params: missing,
named_params: vec![],
body: Box::new(Expr::new(ExprKind::Func(inner_closure))),
return_ty: None,
env: HashMap::new(),
})))
} else {
// resolved, return result
body
}
};

let mut node = Expr::new(ExprKind::Func(Box::new(closure)));
node.ty = Some(Ty {
kind: TyKind::Function(ty),
name: None,
});
// pop the env
self.context.root_mod.stack_pop(NS_PARAM).unwrap();

node
};
r.span = span;
Ok(r)
Ok(res)
}

fn fold_function_types(&mut self, mut closure: Func) -> Result<Func> {
Expand Down Expand Up @@ -968,7 +942,7 @@ impl Resolver {
Ok(kind)
}

fn resolve_column_exclusion(&mut self, expr: Expr) -> Result<Expr, anyhow::Error> {
fn resolve_column_exclusion(&mut self, expr: Expr) -> Result<Expr> {
let expr = self.fold_expr(expr)?;
let tuple = coerce_into_tuple_and_flatten(expr)?;
let except: Vec<Expr> = tuple
Expand Down Expand Up @@ -1017,6 +991,26 @@ impl Resolver {
}
}

fn expr_of_func(func: Func) -> Expr {
let ty = TyFunc {
args: func
.params
.iter()
.skip(func.args.len())
.map(|a| a.ty.as_ref().map(|x| x.as_ty().cloned().unwrap()))
.collect(),
return_ty: Box::new(func.return_ty.as_ref().map(|x| x.as_ty().cloned().unwrap())),
};

Expr {
ty: Some(Ty {
kind: TyKind::Function(ty),
name: None,
}),
..Expr::new(ExprKind::Func(Box::new(func)))
}
}

fn ty_of_lineage(lineage: &Lineage) -> Ty {
Ty::relation(
lineage
Expand Down

0 comments on commit 219d3ea

Please sign in to comment.