Skip to content

Commit

Permalink
feat: support keyword completion (#855)
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <[email protected]>

Support keyword completion in the builtin shell. Press tab to try it
out!

---------

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Nov 23, 2024
1 parent 8b19a09 commit a599403
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
106 changes: 106 additions & 0 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,11 @@ impl Database {
.add_row_count(table_id, count);
Ok(true)
}

/// Return all available pragma options.
fn pragma_options() -> &'static [&'static str] {
&["enable_optimizer", "disable_optimizer"]
}
}

/// The error type of database operations.
Expand Down Expand Up @@ -230,3 +235,104 @@ pub enum Error {
#[error("Internal error: {0}")]
Internal(String),
}

impl rustyline::Helper for &Database {}
impl rustyline::validate::Validator for &Database {}
impl rustyline::highlight::Highlighter for &Database {}
impl rustyline::hint::Hinter for &Database {
type Hint = String;
}

/// Implement SQL completion.
impl rustyline::completion::Completer for &Database {
type Candidate = rustyline::completion::Pair;
fn complete(
&self,
line: &str,
pos: usize,
_ctx: &rustyline::Context<'_>,
) -> rustyline::Result<(usize, Vec<Self::Candidate>)> {
// find the word before cursor
let (prefix, last_word) = line[..pos].rsplit_once(' ').unwrap_or(("", &line[..pos]));

// completion for pragma options
if prefix.trim().eq_ignore_ascii_case("pragma") {
let candidates = Database::pragma_options()
.iter()
.filter(|option| option.starts_with(last_word))
.map(|option| rustyline::completion::Pair {
display: option.to_string(),
replacement: option.to_string(),
})
.collect();
return Ok((pos - last_word.len(), candidates));
}

// TODO: complete table and column names

// completion for keywords

// for a given prefix, all keywords starting with the prefix are returned as candidates
// they should be ordered in principle that frequently used ones come first
const KEYWORDS: &[&str] = &[
"AS", "ALL", "ANALYZE", "CREATE", "COPY", "DELETE", "DROP", "EXPLAIN", "FROM",
"FUNCTION", "INSERT", "JOIN", "ON", "PRAGMA", "SET", "SELECT", "TABLE", "UNION",
"VIEW", "WHERE", "WITH",
];
let last_word_upper = last_word.to_uppercase();
let candidates = KEYWORDS
.iter()
.filter(|command| command.starts_with(&last_word_upper))
.map(|command| rustyline::completion::Pair {
display: command.to_string(),
replacement: format!("{command} "),
})
.collect();
Ok((pos - last_word.len(), candidates))
}
}

#[cfg(test)]
mod tests {
use rustyline::history::DefaultHistory;

use super::*;

#[test]
fn test_completion() {
let db = Database::new_in_memory();
assert_complete(&db, "sel", "SELECT ");
assert_complete(&db, "sel|ect", "SELECT |ect");
assert_complete(&db, "select a f", "select a FROM ");
assert_complete(&db, "pragma en", "pragma enable_optimizer");
}

/// Assert that if complete (e.g. press tab) the given `line`, the result will be
/// `completed_line`.
///
/// Both `line` and `completed_line` can optionally contain a `|` which indicates the cursor
/// position. If not provided, the cursor is assumed to be at the end of the line.
#[track_caller]
fn assert_complete(db: &Database, line: &str, completed_line: &str) {
/// Find cursor position and remove it from the line.
fn get_line_and_cursor(line: &str) -> (String, usize) {
let (before_cursor, after_cursor) = line.split_once('|').unwrap_or((line, ""));
let pos = before_cursor.len();
(format!("{before_cursor}{after_cursor}"), pos)
}
let (mut line, pos) = get_line_and_cursor(line);

// complete
use rustyline::completion::Completer;
let (start_pos, candidates) = db
.complete(&line, pos, &rustyline::Context::new(&DefaultHistory::new()))
.unwrap();
let replacement = &candidates[0].replacement;
line.replace_range(start_pos..pos, replacement);

// assert
let (completed_line, completed_cursor_pos) = get_line_and_cursor(completed_line);
assert_eq!(line, completed_line);
assert_eq!(start_pos + replacement.len(), completed_cursor_pos);
}
}
8 changes: 5 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ use risinglight::storage::SecondaryStorageOptions;
use risinglight::utils::time::RoundingDuration;
use risinglight::Database;
use rustyline::error::ReadlineError;
use rustyline::DefaultEditor;
use rustyline::history::DefaultHistory;
use rustyline::Editor;
use sqllogictest::DefaultColumnType;
use tokio::{select, signal};
use tracing::{info, warn, Level};
Expand Down Expand Up @@ -149,7 +150,7 @@ async fn run_query_in_background(db: Arc<Database>, sql: String, output_format:
///
/// Note that `;` in string literals will also be treated as a terminator
/// as long as it is at the end of a line.
fn read_sql(rl: &mut DefaultEditor) -> Result<String, ReadlineError> {
fn read_sql(rl: &mut Editor<&Database, DefaultHistory>) -> Result<String, ReadlineError> {
let mut sql = String::new();
loop {
let prompt = if sql.is_empty() { "> " } else { "? " };
Expand All @@ -174,7 +175,7 @@ fn read_sql(rl: &mut DefaultEditor) -> Result<String, ReadlineError> {

/// Run RisingLight interactive mode
async fn interactive(db: Database, output_format: Option<String>) -> Result<()> {
let mut rl = DefaultEditor::new()?;
let mut rl = Editor::<&Database, DefaultHistory>::new()?;
let history_path = dirs::cache_dir().map(|p| {
let cache_dir = p.join("risinglight");
std::fs::create_dir_all(cache_dir.as_path()).ok();
Expand All @@ -192,6 +193,7 @@ async fn interactive(db: Database, output_format: Option<String>) -> Result<()>
}

let db = Arc::new(db);
rl.set_helper(Some(&db));

loop {
let read_sql = read_sql(&mut rl);
Expand Down

0 comments on commit a599403

Please sign in to comment.