Skip to content

Commit

Permalink
Make SafeQueryArgsCheck ERROR instead of WARNING
Browse files Browse the repository at this point in the history
  • Loading branch information
fluentfuture committed Dec 1, 2023
1 parent 2f407aa commit e9127d1
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
import static com.google.errorprone.BugPattern.SeverityLevel.ERROR;
import static com.google.errorprone.matchers.Matchers.anyMethod;

import com.google.auto.service.AutoService;
Expand Down Expand Up @@ -30,12 +30,14 @@
summary = "Checks that string placeholders in SQL template strings are quoted.",
link = "go/java-tips/024#preventing-sql-injection",
linkType = LinkType.CUSTOM,
severity = WARNING)
severity = ERROR)
@AutoService(BugChecker.class)
public final class SafeSqlArgsCheck extends AbstractBugChecker
public final class SafeQueryArgsCheck extends AbstractBugChecker
implements AbstractBugChecker.MethodInvocationCheck {
private static final MethodClassMatcher MATCHER =
anyMethod().onDescendantOf("com.google.mu.util.StringFormat.To");
private static final TypeName SAFE_QUERY_TYPE =
new TypeName("com.google.mu.safesql.SafeQuery");
private static final ImmutableSet<TypeName> ARG_TYPES_THAT_SHOULD_NOT_BE_QUOTED =
ImmutableSet.of(
new TypeName("com.google.storage.googlesql.safesql.TrustedSqlString"),
Expand All @@ -51,6 +53,9 @@ public void checkMethodInvocation(MethodInvocationTree tree, VisitorState state)
if (!MATCHER.matches(tree, state)) {
return;
}
if (!SAFE_QUERY_TYPE.isSameType(ASTHelpers.getType(tree), state)) {
return;
}
MethodSymbol symbol = ASTHelpers.getSymbol(tree);
if (!symbol.isVarArgs() || symbol.getParameters().size() != 1) {
return;
Expand All @@ -72,6 +77,9 @@ public void checkMethodInvocation(MethodInvocationTree tree, VisitorState state)
}
for (int i = 0; i < placeholders.size(); i++) {
Substring.Match placeholder = placeholders.get(i);
if (placeholder.isImmediatelyBetween("`", "`")) {
continue;
}
ExpressionTree arg = tree.getArguments().get(i);
Type type = ASTHelpers.getType(arg);
if (placeholder.isImmediatelyBetween("'", "'")
Expand All @@ -84,14 +92,15 @@ public void checkMethodInvocation(MethodInvocationTree tree, VisitorState state)
"argument of type %s should not be quoted: '%s'",
type,
placeholder);
} else if (!placeholder.isImmediatelyBetween("`", "`")) {
// Disallow arbitrary string literals or characters unless backquoted.
} else { // Disallow arbitrary string literals or characters unless backquoted.
checkingOn(arg)
.require(
ARG_TYPES_THAT_MUST_BE_QUOTED.stream().noneMatch(t -> t.isSameType(type, state)),
"argument of type %s must be quoted (for example '%s')",
type,
placeholder);
.require(
ARG_TYPES_THAT_MUST_BE_QUOTED.stream().noneMatch(t -> t.isSameType(type, state)),
"argument of type %s must be quoted (for example '%s' for string literals or `%s`"
+ " for identifiers)",
type,
placeholder,
placeholder);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class SafeSqlArgsCheckTest {
public final class SafeQueryArgsCheckTest {
private final CompilationTestHelper helper =
CompilationTestHelper.newInstance(SafeSqlArgsCheck.class, getClass());
CompilationTestHelper.newInstance(SafeQueryArgsCheck.class, getClass());

@Test
public void stringArgCanBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT '{v}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT '{v}' FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with(\"value\");",
" }",
"}")
Expand All @@ -32,12 +32,12 @@ public void stringArgCanBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT \\\"{v}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT \\\"{v}\\\" FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with(\"value\");",
" }",
"}")
Expand All @@ -49,12 +49,12 @@ public void charArgCanBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT '{v}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT '{v}' FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with('v');",
" }",
"}")
Expand All @@ -66,12 +66,12 @@ public void charArgCanBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT \\\"{v}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT \\\"{v}\\\" FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with('v');",
" }",
"}")
Expand All @@ -83,12 +83,12 @@ public void stringArgMustBeQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT {c} FROM tbl\");",
" SafeQuery test() {",
" // BUG: Diagnostic contains: java.lang.String must be quoted",
" return SELECT.with(\"column\");",
" }",
Expand All @@ -101,12 +101,12 @@ public void stringArgCanBeBackquoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT `{c}` FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT `{c}` FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with(\"column\");",
" }",
"}")
Expand All @@ -118,13 +118,13 @@ public void charArgMustBeQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" // BUG: Diagnostic contains: char must be quoted (for example '{c}')",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT {c} FROM tbl\");",
" SafeQuery test() {",
" // BUG: Diagnostic contains: char must be quoted (for example '{c}'",
" return SELECT.with('x');",
" }",
"}")
Expand All @@ -136,13 +136,12 @@ public void safeQueryCannotBeSingleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> UPDATE =",
" StringFormat.to(Sql::new, \"UPDATE '{c}' FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> UPDATE =",
" SafeQuery.template(\"UPDATE '{c}' FROM tbl\");",
" SafeQuery test() {",
" // BUG: Diagnostic contains: SafeQuery should not be quoted: '{c}'",
" return UPDATE.with(SafeQuery.of(\"foo\"));",
" }",
Expand All @@ -155,13 +154,12 @@ public void safeQueryNotQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> SELECT =",
" StringFormat.to(Sql::new, \"SELECT {c} FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> SELECT =",
" SafeQuery.template(\"SELECT {c} FROM tbl\");",
" SafeQuery test() {",
" return SELECT.with(SafeQuery.of(\"column\"));",
" }",
"}")
Expand All @@ -173,13 +171,12 @@ public void safeQueryCannotBeDoubleQuoted() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" private static final StringFormat.To<Sql> UPDATE =",
" StringFormat.to(Sql::new, \"UPDATE \\\"{c}\\\" FROM tbl\");",
" Sql(String query) {} ",
" Sql test() {",
" private static final StringFormat.To<SafeQuery> UPDATE =",
" SafeQuery.template(\"UPDATE \\\"{c}\\\" FROM tbl\");",
" SafeQuery test() {",
" // BUG: Diagnostic contains: SafeQuery should not be quoted: '{c}'",
" return UPDATE.with(SafeQuery.of(\"foo\"));",
" }",
Expand All @@ -192,26 +189,27 @@ public void formatStringNotFound() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.safesql.SafeQuery;",
"import com.google.mu.util.StringFormat;",
"class Sql {",
" Sql test(StringFormat.To<Sql> select) {",
" SafeQuery test(StringFormat.To<SafeQuery> select) {",
" return select.with('x');",
" }",
"}")
.doTest();
}

@Test
public void nonSql_notChecked() {
public void nonSafeQuery_notChecked() {
helper
.addSourceLines(
"Test.java",
"import com.google.mu.util.StringFormat;",
"class NotSql {",
" private static final StringFormat.To<NotSql> UPDATE =",
" StringFormat.to(NotSql::new, \"update {c}\");",
" NotSql(String query) {} ",
" NotSql test() {",
"class NotSafeQuery {",
" private static final StringFormat.To<NotSafeQuery> UPDATE =",
" StringFormat.to(NotSafeQuery::new, \"update {c}\");",
" NotSafeQuery(String query) {} ",
" NotSafeQuery test() {",
" return UPDATE.with('x');",
" }",
"}")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.google.mu.safesql;


import static com.google.common.truth.Truth.assertThat;
import static com.google.mu.safesql.SafeQuery.template;
import static java.util.Arrays.asList;
Expand Down

0 comments on commit e9127d1

Please sign in to comment.