Skip to content

Commit

Permalink
SafeSqlArgsCheck
Browse files Browse the repository at this point in the history
  • Loading branch information
fluentfuture committed Nov 29, 2023
1 parent da92ae4 commit a5b0650
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import static com.google.mu.util.Optionals.optionally;
import static com.google.mu.util.Substring.consecutive;
import static com.google.mu.util.Substring.first;
import static com.google.mu.util.Substring.firstOccurrence;
import static com.google.mu.util.Substring.BoundStyle.INCLUSIVE;

import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import com.google.common.base.Ascii;
import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.VisitorState;
Expand Down Expand Up @@ -63,6 +66,28 @@ static Optional<String> findFormatString(Tree unformatter, VisitorState state) {
.map(tree -> ASTHelpers.constValue(tree, String.class));
}

static boolean looksLikeSql(String template) {
return looksLikeQuery().or(looksLikeInsert()).in(Ascii.toLowerCase(template)).isPresent();
}

private static Substring.Pattern looksLikeQuery() {
return Stream.of("select", "update", "delete")
.map(w -> keyword(w))
.collect(firstOccurrence())
.peek(keyword("from").or(keyword("where")))
.peek(PLACEHOLDER_PATTERN);
}

private static Substring.Pattern looksLikeInsert() {
return keyword("insert into")
.peek(keyword("values").or(keyword("select")))
.peek(PLACEHOLDER_PATTERN);
}

private static Substring.Pattern keyword(String word) {
return first(word).separatedBy(CharMatcher.whitespace().or(CharMatcher.anyOf("()"))::matches);
}

private static List<? extends ExpressionTree> invocationArgs(Tree tree) {
if (tree instanceof NewClassTree) {
return ((NewClassTree) tree).getArguments();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package com.google.mu.errorprone;


import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.matchers.Matchers.anyMethod;

import com.google.auto.service.AutoService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.mu.util.Substring;
import com.google.errorprone.BugPattern;
import com.google.errorprone.BugPattern.LinkType;
import com.google.errorprone.VisitorState;
import com.google.errorprone.bugpatterns.BugChecker;
import com.google.errorprone.matchers.method.MethodMatchers.MethodClassMatcher;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.tools.javac.code.Symbol.MethodSymbol;
import com.sun.tools.javac.code.Type;

/**
* Warns against potential SQL injection risks caused by unquoted string placeholders. This class
* mainly checks that all string placeholder values must correspond to a quoted placeholder such as
* {@code '{foo}'}; and unquoted placeholders should only accept trusted types (enums, numbers,
* booleans, TrustedSqlString etc.)
*/
@BugPattern(
summary = "Checks that string placeholders in SQL template strings are quoted.",
link = "go/java-tips/024#preventing-sql-injection",
linkType = LinkType.CUSTOM,
severity = WARNING)
@AutoService(BugChecker.class)
public final class SafeSqlArgsCheck extends AbstractBugChecker
implements AbstractBugChecker.MethodInvocationCheck {
private static final MethodClassMatcher MATCHER =
anyMethod().onDescendantOf("com.google.mu.util.StringFormat.To");
private static final ImmutableSet<TypeName> ARG_TYPES_THAT_SHOULD_NOT_BE_QUOTED =
ImmutableSet.of(
new TypeName("com.google.storage.googlesql.safesql.TrustedSqlString"),
new TypeName("com.google.mu.safesql.SafeQuery"),
new TypeName("com.google.protobuf.Timestamp"));
private static final ImmutableSet<TypeName> ARG_TYPES_THAT_MUST_BE_QUOTED =
ImmutableSet.of(
TypeName.of(String.class), TypeName.of(Character.class), TypeName.of(char.class));

@Override
public void checkMethodInvocation(MethodInvocationTree tree, VisitorState state)
throws ErrorReport {
if (!MATCHER.matches(tree, state)) {
return;
}
MethodSymbol symbol = ASTHelpers.getSymbol(tree);
if (!symbol.isVarArgs() || symbol.getParameters().size() != 1) {
return;
}
ExpressionTree formatter = ASTHelpers.getReceiver(tree);
String formatString = FormatStringUtils.findFormatString(formatter, state).orElse(null);
if (formatString == null || !FormatStringUtils.looksLikeSql(formatString)) {
return;
}
ImmutableList<Substring.Match> placeholders =
FormatStringUtils
.PLACEHOLDER_PATTERN
.repeatedly()
.match(formatString)
.collect(toImmutableList());
if (placeholders.size() != tree.getArguments().size()) {
// Shouldn't happen. Will leave it to the other checks to report.
return;
}
for (int i = 0; i < placeholders.size(); i++) {
Substring.Match placeholder = placeholders.get(i);
ExpressionTree arg = tree.getArguments().get(i);
Type type = ASTHelpers.getType(arg);
if (placeholder.isImmediatelyBetween("'", "'")
|| placeholder.isImmediatelyBetween("\"", "\"")) {
// It's a quoted string literal. Do not use sql query types
checkingOn(arg)
.require(
ARG_TYPES_THAT_SHOULD_NOT_BE_QUOTED.stream()
.noneMatch(t -> t.isSameType(type, state)),
"argument of type %s should not be quoted: '%s'",
type,
placeholder);
} else if (!placeholder.isImmediatelyBetween("`", "`")) {
// Disallow arbitrary string literals or characters unless backquoted.
checkingOn(arg)
.require(
ARG_TYPES_THAT_MUST_BE_QUOTED.stream().noneMatch(t -> t.isSameType(type, state)),
"argument of type %s must be quoted (for example '%s')",
type,
placeholder);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
package com.google.mu.errorprone;

import com.google.errorprone.CompilationTestHelper;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class SafeSqlArgsCheckTest {
private final CompilationTestHelper helper =
CompilationTestHelper.newInstance(SafeSqlArgsCheck.class, getClass());

@Test
public void stringArgCanBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT '{v}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with(\"value\");",
" }",
"}")
.doTest();
}

@Test
public void stringArgCanBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT \\\"{v}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with(\"value\");",
" }",
"}")
.doTest();
}

@Test
public void charArgCanBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT '{v}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with('v');",
" }",
"}")
.doTest();
}

@Test
public void charArgCanBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT \\\"{v}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with('v');",
" }",
"}")
.doTest();
}

@Test
public void stringArgMustBeQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" // BUG: Diagnostic contains: java.lang.String must be quoted",
" return SELECT.with(\"column\");",
" }",
"}")
.doTest();
}

@Test
public void stringArgCanBeBackquoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT `{c}` FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with(\"column\");",
" }",
"}")
.doTest();
}

@Test
public void charArgMustBeQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" // BUG: Diagnostic contains: char must be quoted (for example '{c}')",
" return SELECT.with('x');",
" }",
"}")
.doTest();
}

@Test
public void safeQueryCannotBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"class Sql {",
" private static final StringFormat.To<Sql> UPDATE =",
" StringFormat.to(Sql::new, \"UPDATE '{c}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" // BUG: Diagnostic contains: SafeQuery should not be quoted: '{c}'",
" return UPDATE.with(SafeQuery.of(\"foo\"));",
" }",
"}")
.doTest();
}

@Test
public void safeQueryNotQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" return SELECT.with(SafeQuery.of(\"column\"));",
" }",
"}")
.doTest();
}

@Test
public void safeQueryCannotBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"class Sql {",
" private static final StringFormat.To<Sql> UPDATE =",
" StringFormat.to(Sql::new, \"UPDATE \\\"{c}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" // BUG: Diagnostic contains: SafeQuery should not be quoted: '{c}'",
" return UPDATE.with(SafeQuery.of(\"foo\"));",
" }",
"}")
.doTest();
}

@Test
public void formatStringNotFound() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" Sql test(StringFormat.To<Sql> select) {",
" return select.with('x');",
" }",
"}")
.doTest();
}

@Test
public void nonSql_notChecked() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class NotSql {",
" private static final StringFormat.To<NotSql> UPDATE =",
" StringFormat.to(NotSql::new, \"update {c}\");",
" NotSql(String query) {} ",
" NotSql test() {",
" return UPDATE.with('x');",
" }",
"}")
.doTest();
}
}
2 changes: 0 additions & 2 deletions mug/src/main/java/com/google/mu/util/Substring.java
Original file line number Diff line number Diff line change
Expand Up @@ -1251,9 +1251,7 @@ public final Pattern peek(String following) {
* a character.
*
* @since 6.0
* @deprecated Use {@link #followedBy} instead.
*/
@Deprecated
public Pattern peek(Pattern following) {
requireNonNull(following);
Pattern base = this;
Expand Down

0 comments on commit a5b0650

Please sign in to comment.