diff --git a/benchmarks/src/jmh/java/build/buf/protovalidate/benchmarks/ValidationBenchmark.java b/benchmarks/src/jmh/java/build/buf/protovalidate/benchmarks/ValidationBenchmark.java index b0ffeb86..d058b20b 100644 --- a/benchmarks/src/jmh/java/build/buf/protovalidate/benchmarks/ValidationBenchmark.java +++ b/benchmarks/src/jmh/java/build/buf/protovalidate/benchmarks/ValidationBenchmark.java @@ -17,6 +17,7 @@ import build.buf.protovalidate.Validator; import build.buf.protovalidate.ValidatorFactory; import build.buf.protovalidate.benchmarks.gen.ManyUnruledFieldsMessage; +import build.buf.protovalidate.benchmarks.gen.RegexPatternMessage; import build.buf.protovalidate.benchmarks.gen.RepeatedRuleMessage; import build.buf.protovalidate.benchmarks.gen.SimpleStringMessage; import build.buf.protovalidate.exceptions.ValidationException; @@ -40,6 +41,7 @@ public class ValidationBenchmark { private SimpleStringMessage simple; private ManyUnruledFieldsMessage manyUnruled; private RepeatedRuleMessage repeatedRule; + private RegexPatternMessage regexPattern; @Setup public void setup() throws ValidationException { @@ -67,10 +69,13 @@ public void setup() throws ValidationException { } repeatedRule = repeatedRuleBuilder.build(); + regexPattern = RegexPatternMessage.newBuilder().setName("Alice Example").build(); + // Warm evaluator cache for steady-state benchmarks. validator.validate(simple); validator.validate(manyUnruled); validator.validate(repeatedRule); + validator.validate(regexPattern); } // Steady-state validate() benchmarks. These exercise the hot path after the @@ -90,4 +95,9 @@ public void validateManyUnruled(Blackhole bh) throws ValidationException { public void validateRepeatedRule(Blackhole bh) throws ValidationException { bh.consume(validator.validate(repeatedRule)); } + + @Benchmark + public void validateRegexPattern(Blackhole bh) throws ValidationException { + bh.consume(validator.validate(regexPattern)); + } } diff --git a/benchmarks/src/jmh/proto/bench/v1/bench.proto b/benchmarks/src/jmh/proto/bench/v1/bench.proto index a1e47c2e..d6fde070 100644 --- a/benchmarks/src/jmh/proto/bench/v1/bench.proto +++ b/benchmarks/src/jmh/proto/bench/v1/bench.proto @@ -69,3 +69,13 @@ message RepeatedRuleMessage { string f19 = 19 [(buf.validate.field).string.min_len = 1]; string f20 = 20 [(buf.validate.field).string.min_len = 1]; } + +// Single string field with a string.pattern rule. Targets the regex +// recompile-per-evaluation cost: the CEL runtime's matches() calls +// Pattern.compile on every invocation. +message RegexPatternMessage { + string name = 1 [(buf.validate.field).string = { + pattern: "^[[:alpha:]]+( [[:alpha:]]+)*$" + max_bytes: 256 + }]; +} diff --git a/build.gradle.kts b/build.gradle.kts index eb07a152..04472205 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -398,6 +398,7 @@ dependencies { api(libs.jspecify) api(libs.protobuf.java) implementation(libs.cel) + implementation(libs.re2j) buf("build.buf:buf:${libs.versions.buf.get()}:${osdetector.classifier}@exe") diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index ff7c394e..bb1ea4c6 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -6,6 +6,7 @@ error-prone = "2.49.0" junit = "5.14.3" maven-publish = "0.36.0" protobuf = "4.34.1" +re2j = "1.8" [libraries] assertj = { module = "org.assertj:assertj-core", version.ref = "assertj" } @@ -19,6 +20,7 @@ junit-bom = { module = "org.junit:junit-bom", version.ref = "junit" } maven-plugin = { module = "com.vanniktech:gradle-maven-publish-plugin", version.ref = "maven-publish" } nullaway = { module = "com.uber.nullaway:nullaway", version = "0.13.3" } protobuf-java = { module = "com.google.protobuf:protobuf-java", version.ref = "protobuf" } +re2j = { module = "com.google.re2j:re2j", version.ref = "re2j" } spotless = { module = "com.diffplug.spotless:spotless-plugin-gradle", version = "8.4.0" } [plugins] diff --git a/src/main/java/build/buf/protovalidate/CustomDeclarations.java b/src/main/java/build/buf/protovalidate/CustomDeclarations.java index f1600477..a4f61bf8 100644 --- a/src/main/java/build/buf/protovalidate/CustomDeclarations.java +++ b/src/main/java/build/buf/protovalidate/CustomDeclarations.java @@ -104,6 +104,17 @@ static List create() { newMemberOverload( "is_hostname", SimpleType.BOOL, Collections.singletonList(SimpleType.STRING)))); + // Redeclare 'matches' with the same overload ids as the stdlib. + decls.add( + newFunctionDeclaration( + "matches", + newGlobalOverload( + "matches", SimpleType.BOOL, Arrays.asList(SimpleType.STRING, SimpleType.STRING)), + newMemberOverload( + "matches_string", + SimpleType.BOOL, + Arrays.asList(SimpleType.STRING, SimpleType.STRING)))); + decls.add( newFunctionDeclaration( "isHostAndPort", diff --git a/src/main/java/build/buf/protovalidate/CustomOverload.java b/src/main/java/build/buf/protovalidate/CustomOverload.java index 39a276d6..dd631eed 100644 --- a/src/main/java/build/buf/protovalidate/CustomOverload.java +++ b/src/main/java/build/buf/protovalidate/CustomOverload.java @@ -16,6 +16,10 @@ import com.google.protobuf.Descriptors; import com.google.protobuf.Message; +import com.google.re2j.Matcher; +import com.google.re2j.Pattern; +import com.google.re2j.PatternSyntaxException; +import dev.cel.common.CelOptions; import dev.cel.common.types.CelType; import dev.cel.common.types.SimpleType; import dev.cel.common.values.CelByteString; @@ -28,7 +32,7 @@ import java.util.List; import java.util.Locale; import java.util.Set; -import java.util.regex.Pattern; +import java.util.concurrent.ConcurrentMap; /** Defines custom function overloads (the implementation). */ final class CustomOverload { @@ -39,11 +43,14 @@ final class CustomOverload { "^[a-zA-Z0-9.!#$%&'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$"); /** - * Create custom function overload list. + * Create a list of custom function overloads. * + * @param patternCache cache used by the {@code matches}/{@code matches_string} overrides. + * @param celOptions CEL options the enclosing runtime is built with. * @return a list of overloaded functions. */ - static List create() { + static List create( + ConcurrentMap patternCache, CelOptions celOptions) { ArrayList bindings = new ArrayList<>(); bindings.addAll( Arrays.asList( @@ -65,7 +72,9 @@ static List create() { celIsNan(), celIsInfUnary(), celIsInfBinary(), - celIsHostAndPort())); + celIsHostAndPort(), + celMatches(patternCache, celOptions), + celMatchesString(patternCache, celOptions))); bindings.addAll(celUnique()); return Collections.unmodifiableList(bindings); } @@ -356,6 +365,41 @@ private static CelFunctionBinding celIsHostAndPort() { CustomOverload::isHostAndPort); } + /** Caching replacement for CEL's global {@code matches(string, string)}. */ + @SuppressWarnings("Immutable") + private static CelFunctionBinding celMatches( + ConcurrentMap patternCache, CelOptions celOptions) { + return CelFunctionBinding.from( + "matches", + String.class, + String.class, + (value, regex) -> matches(patternCache, celOptions, value, regex)); + } + + /** Caching replacement for CEL's member-style {@code string.matches(string)}. */ + @SuppressWarnings("Immutable") + private static CelFunctionBinding celMatchesString( + ConcurrentMap patternCache, CelOptions celOptions) { + return CelFunctionBinding.from( + "matches_string", + String.class, + String.class, + (value, regex) -> matches(patternCache, celOptions, value, regex)); + } + + private static boolean matches( + ConcurrentMap cache, CelOptions celOptions, String value, String regex) + throws CelEvaluationException { + Pattern pattern; + try { + pattern = cache.computeIfAbsent(regex, Pattern::compile); + } catch (PatternSyntaxException e) { + throw new CelEvaluationException("failed to compile regex: " + e.getMessage(), e); + } + Matcher matcher = pattern.matcher(value); + return celOptions.enableRegexPartialMatch() ? matcher.find() : matcher.matches(); + } + /** * Returns true if the string is a valid host/port pair, for example "example.com:8080". * diff --git a/src/main/java/build/buf/protovalidate/ValidateLibrary.java b/src/main/java/build/buf/protovalidate/ValidateLibrary.java index bbe2074d..0300f606 100644 --- a/src/main/java/build/buf/protovalidate/ValidateLibrary.java +++ b/src/main/java/build/buf/protovalidate/ValidateLibrary.java @@ -14,14 +14,23 @@ package build.buf.protovalidate; +import com.google.re2j.Pattern; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; import dev.cel.checker.CelCheckerBuilder; +import dev.cel.checker.CelStandardDeclarations; +import dev.cel.common.CelOptions; import dev.cel.common.CelVarDecl; import dev.cel.common.types.SimpleType; import dev.cel.compiler.CelCompilerLibrary; +import dev.cel.extensions.CelExtensions; import dev.cel.parser.CelParserBuilder; import dev.cel.parser.CelStandardMacro; import dev.cel.runtime.CelRuntimeBuilder; import dev.cel.runtime.CelRuntimeLibrary; +import dev.cel.runtime.CelStandardFunctions; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; /** * Custom {@link CelCompilerLibrary} and {@link CelRuntimeLibrary}. Provides all the custom @@ -29,9 +38,35 @@ */ final class ValidateLibrary implements CelCompilerLibrary, CelRuntimeLibrary { + private static final CelOptions CEL_OPTIONS = CelOptions.DEFAULT; + + private final ConcurrentMap patternCache = new ConcurrentHashMap<>(); + /** Creates a ValidateLibrary with all custom declarations and overloads. */ ValidateLibrary() {} + static Cel newCel() { + ValidateLibrary validateLibrary = new ValidateLibrary(); + // NOTE: CelExtensions.strings() does not implement string.reverse() or strings.quote() which + // are available in protovalidate-go. Fixed in https://github.com/google/cel-java/pull/998. + return CelFactory.standardCelBuilder() + .setOptions(CEL_OPTIONS) + // Drop stdlib matches; CustomOverload provides a caching replacement. + // Ref: https://github.com/google/cel-java/issues/1038 + .setStandardEnvironmentEnabled(false) + .setStandardDeclarations( + CelStandardDeclarations.newBuilder() + .excludeFunctions(CelStandardDeclarations.StandardFunction.MATCHES) + .build()) + .setStandardFunctions( + CelStandardFunctions.newBuilder() + .excludeFunctions(CelStandardFunctions.StandardFunction.MATCHES) + .build()) + .addCompilerLibraries(validateLibrary, CelExtensions.strings()) + .addRuntimeLibraries(validateLibrary, CelExtensions.strings()) + .build(); + } + @Override public void setParserOptions(CelParserBuilder parserBuilder) { parserBuilder.setStandardMacros( @@ -54,6 +89,6 @@ public void setCheckerOptions(CelCheckerBuilder checkerBuilder) { @Override public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) { - runtimeBuilder.addFunctionBindings(CustomOverload.create()); + runtimeBuilder.addFunctionBindings(CustomOverload.create(patternCache, CEL_OPTIONS)); } } diff --git a/src/main/java/build/buf/protovalidate/ValidatorImpl.java b/src/main/java/build/buf/protovalidate/ValidatorImpl.java index 9d749a91..39613c9f 100644 --- a/src/main/java/build/buf/protovalidate/ValidatorImpl.java +++ b/src/main/java/build/buf/protovalidate/ValidatorImpl.java @@ -18,10 +18,6 @@ import build.buf.protovalidate.exceptions.ValidationException; import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Message; -import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; -import dev.cel.common.CelOptions; -import dev.cel.extensions.CelExtensions; import java.util.ArrayList; import java.util.List; @@ -36,13 +32,14 @@ final class ValidatorImpl implements Validator { private final boolean failFast; ValidatorImpl(Config config) { - this.evaluatorBuilder = new EvaluatorBuilder(newCel(), config); + this.evaluatorBuilder = new EvaluatorBuilder(ValidateLibrary.newCel(), config); this.failFast = config.isFailFast(); } ValidatorImpl(Config config, List descriptors, boolean disableLazy) throws CompilationException { - this.evaluatorBuilder = new EvaluatorBuilder(newCel(), config, descriptors, disableLazy); + this.evaluatorBuilder = + new EvaluatorBuilder(ValidateLibrary.newCel(), config, descriptors, disableLazy); this.failFast = config.isFailFast(); } @@ -63,16 +60,4 @@ public ValidationResult validate(Message msg) throws ValidationException { } return new ValidationResult(violations); } - - private static Cel newCel() { - ValidateLibrary validateLibrary = new ValidateLibrary(); - // NOTE: CelExtensions.strings() does not implement string.reverse() or strings.quote() which - // are available in protovalidate-go. - return CelFactory.standardCelBuilder() - .addCompilerLibraries(validateLibrary, CelExtensions.strings()) - .addRuntimeLibraries(validateLibrary, CelExtensions.strings()) - .setOptions( - CelOptions.DEFAULT.toBuilder().evaluateCanonicalTypesToNativeValues(true).build()) - .build(); - } } diff --git a/src/test/java/build/buf/protovalidate/CustomOverloadTest.java b/src/test/java/build/buf/protovalidate/CustomOverloadTest.java index a7a6fcb9..73658dc0 100644 --- a/src/test/java/build/buf/protovalidate/CustomOverloadTest.java +++ b/src/test/java/build/buf/protovalidate/CustomOverloadTest.java @@ -18,9 +18,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import dev.cel.bundle.Cel; -import dev.cel.bundle.CelFactory; import dev.cel.common.CelAbstractSyntaxTree; -import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; import dev.cel.runtime.CelEvaluationException; @@ -31,14 +29,7 @@ public class CustomOverloadTest { - private final ValidateLibrary validateLibrary = new ValidateLibrary(); - private final Cel cel = - CelFactory.standardCelBuilder() - .addCompilerLibraries(validateLibrary) - .addRuntimeLibraries(validateLibrary) - .setOptions( - CelOptions.DEFAULT.toBuilder().evaluateCanonicalTypesToNativeValues(true).build()) - .build(); + private final Cel cel = ValidateLibrary.newCel(); @Test public void testIsInf() throws Exception { @@ -173,6 +164,19 @@ public void testBytesContains() throws Exception { assertThat(evalToBool("bytes('12345').contains(bytes('123456'))")).isFalse(); } + @Test + public void testMatchesPartialMatch() throws Exception { + // CelOptions.DEFAULT sets enableRegexPartialMatch(true), so an unanchored regex should + // match anywhere in the input (find()), not require a full-string match. + assertThat(evalToBool("'hello world'.matches('world')")).isTrue(); + assertThat(evalToBool("'hello world'.matches('ell')")).isTrue(); + // Anchored patterns still behave the same. + assertThat(evalToBool("'hello'.matches('^hello$')")).isTrue(); + assertThat(evalToBool("'hello world'.matches('^hello$')")).isFalse(); + // Global form. + assertThat(evalToBool("matches('hello world', 'world')")).isTrue(); + } + private Object eval(String source) throws Exception { return eval(source, Collections.emptyMap()); } diff --git a/src/test/java/build/buf/protovalidate/FormatTest.java b/src/test/java/build/buf/protovalidate/FormatTest.java index 3e9f460f..c167e03a 100644 --- a/src/test/java/build/buf/protovalidate/FormatTest.java +++ b/src/test/java/build/buf/protovalidate/FormatTest.java @@ -26,8 +26,6 @@ import com.google.protobuf.TextFormat; import dev.cel.bundle.Cel; import dev.cel.bundle.CelBuilder; -import dev.cel.bundle.CelFactory; -import dev.cel.common.CelOptions; import dev.cel.common.CelValidationException; import dev.cel.common.CelValidationResult; import dev.cel.common.types.SimpleType; @@ -87,14 +85,7 @@ public static void setUp() throws Exception { .flatMap(s -> s.getTestList().stream()) .collect(Collectors.toList()); - ValidateLibrary validateLibrary = new ValidateLibrary(); - cel = - CelFactory.standardCelBuilder() - .addCompilerLibraries(validateLibrary) - .addRuntimeLibraries(validateLibrary) - .setOptions( - CelOptions.DEFAULT.toBuilder().evaluateCanonicalTypesToNativeValues(true).build()) - .build(); + cel = ValidateLibrary.newCel(); } @ParameterizedTest