commit 8bb2265abdaf6d565f69986d81b9541c6a1eee1f Author: wea_ondara Date: Mon Jul 11 21:18:54 2022 +0200 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1ebf457 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +target/ +*.iml +.idea/ diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..ba7a7a9 --- /dev/null +++ b/pom.xml @@ -0,0 +1,137 @@ + + + 4.0.0 + + jef + jef + 1.0 + + + UTF-8 + 17 + 17 + 9.3 + 1.18.16 + + + + + java9plus + + [1.9,) + + + ${maven.compiler.target} + + + + java11plus + + [1.11,) + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.1 + + true + + -J--add-opens=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.comp=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.main=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.model=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED + -J--add-opens=jdk.compiler/com.sun.tools.javac.jvm=ALL-UNNAMED + + + + + org.projectlombok + lombok + ${lombok.version} + + + + + + + + + + ${project.artifactId} + + + src/main/resources + true + + + + + org.apache.maven.plugins + maven-help-plugin + 3.1.0 + + + show-profiles + validate + + active-profiles + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + + org.apache.maven.plugins + maven-resources-plugin + 2.7 + + ${project.build.sourceEncoding} + + + + + + + + org.ow2.asm + asm-tree + ${asm.version} + test + + + org.ow2.asm + asm + ${asm.version} + + + org.ow2.asm + asm-commons + ${asm.version} + + + org.junit.jupiter + junit-jupiter + 5.8.2 + test + + + org.projectlombok + lombok + ${lombok.version} + compile + + + \ No newline at end of file diff --git a/src/main/java/jef/DBSet.java b/src/main/java/jef/DBSet.java new file mode 100644 index 0000000..442cea4 --- /dev/null +++ b/src/main/java/jef/DBSet.java @@ -0,0 +1,229 @@ +package jef; + +import jef.expressions.Expression; +import jef.expressions.SelectExpression; +import jef.expressions.TableExpression; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.List; +import java.util.Spliterator; + +public class DBSet implements Queryable { + private final String table; + + public DBSet(String table) { + this.table = table; + } + + @Override + public String getTableAlias() { + return String.valueOf((char) ('a' - 1)); + } + + @Override + public Expression getExpression() { + return new SelectExpression(List.of("*"), new TableExpression(table), ""); + } + + @Override + public String toString() { + return "SELECT * FROM `" + table + "`"; + } + + @Override + public Iterator iterator() { + return null; + } + + @Override + public Spliterator spliterator() { + return null; + } + + @Override + public boolean isParallel() { + return false; + } + + @Override + public DBSet sequential() { + return this; + } + + @Override + public DBSet parallel() { + return this; + } + + @Override + public DBSet unordered() { + return this; + } + + @Override + public DBSet onClose(Runnable runnable) { + return this; + } + + @Override + public void close() { + + } + + //stream +// @Override +// public Queryable map(Function function) { +// return null; +// } +// +// @Override +// public IntStream mapToInt(ToIntFunction toIntFunction) { +// return null; +// } +// +// @Override +// public LongStream mapToLong(ToLongFunction toLongFunction) { +// return null; +// } +// +// @Override +// public DoubleStream mapToDouble(ToDoubleFunction toDoubleFunction) { +// return null; +// } +// +// @Override +// public Queryable flatMap(Function> function) { +// return null; +// } +// +// @Override +// public IntStream flatMapToInt(Function function) { +// return null; +// } +// +// @Override +// public LongStream flatMapToLong(Function function) { +// return null; +// } +// +// @Override +// public DoubleStream flatMapToDouble(Function function) { +// return null; +// } +// +// @Override +// public DBSet distinct() { +// return null; +// } +// +// @Override +// public DBSet sorted() { +// return null; +// } +// +// @Override +// public DBSet sorted(Comparator comparator) { +// return null; +// } +// +// @Override +// public DBSet peek(Consumer consumer) { +// return null; +// } +// +// @Override +// public DBSet limit(long l) { +// return null; +// } +// +// @Override +// public DBSet skip(long l) { +// return null; +// } + + // +// @Override +// public void forEach(Consumer consumer) { +// } +// +// @Override +// public void forEachOrdered(Consumer consumer) { +// } +// +// @Override +// public Object[] toArray() { +// return new Object[0]; +// } +// +// @Override +// public A[] toArray(IntFunction intFunction) { +// return null; +// } +// +// @Override +// public T reduce(T t, BinaryOperator binaryOperator) { +// return null; +// } +// +// @Override +// public Optional reduce(BinaryOperator binaryOperator) { +// return Optional.empty(); +// } +// +// @Override +// public U reduce(U u, BiFunction biFunction, BinaryOperator binaryOperator) { +// return null; +// } +// +// @Override +// public R collect(Supplier supplier, BiConsumer biConsumer, BiConsumer biConsumer1) { +// return null; +// } +// +// @Override +// public R collect(Collector collector) { +// return null; +// } +// +// @Override +// public Optional min(Comparator comparator) { +// return Optional.empty(); +// } +// +// @Override +// public Optional max(Comparator comparator) { +// return Optional.empty(); +// } +// +// @Override +// public long count() { +// return 0; +// } +// +// @Override +// public boolean anyMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// @Override +// public boolean allMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// @Override +// public boolean noneMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// @Override +// public Optional findFirst() { +// return Optional.empty(); +// } +// +// @Override +// public Optional findAny() { +// return Optional.empty(); +// } + // +} diff --git a/src/main/java/jef/Queryable.java b/src/main/java/jef/Queryable.java new file mode 100644 index 0000000..2b031f5 --- /dev/null +++ b/src/main/java/jef/Queryable.java @@ -0,0 +1,181 @@ +package jef; + +import jef.expressions.Expression; +import jef.operations.FilterOp; +import jef.serializable.SerializablePredicate; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.Spliterator; + +public interface Queryable { + + // + String getTableAlias(); + + Expression getExpression(); + + String toString(); + + //stream functions + default Iterator iterator() { + return null; + } + + default Spliterator spliterator() { + return null; + } + + default boolean isParallel() { + return false; + } + + default Queryable sequential() { + return this; + } + + default Queryable parallel() { + return this; + } + + default Queryable unordered() { + return this; + } + + default Queryable onClose(Runnable runnable) { + return this; + } + + default void close() { + + } + + //stream + default Queryable filter(SerializablePredicate predicate) { + return new FilterOp(this, predicate); + } + +// default Queryable map(Function function) { +// return null; +// } +// +// default IntStream mapToInt(ToIntFunction toIntFunction) { +// return null; +// } +// +// default LongStream mapToLong(ToLongFunction toLongFunction) { +// return null; +// } +// +// default DoubleStream mapToDouble(ToDoubleFunction toDoubleFunction) { +// return null; +// } +// +// default Queryable flatMap(Function> function) { +// return null; +// } +// +// default IntStream flatMapToInt(Function function) { +// return null; +// } +// +// default LongStream flatMapToLong(Function function) { +// return null; +// } +// +// default DoubleStream flatMapToDouble(Function function) { +// return null; +// } +// +// default Queryable distinct() { +// return null; +// } +// +// default Queryable sorted() { +// return null; +// } +// +// default Queryable sorted(Comparator comparator) { +// return null; +// } +// +// default Queryable peek(Consumer consumer) { +// return null; +// } +// +// default Queryable limit(long l) { +// return null; +// } +// +// default Queryable skip(long l) { +// return null; +// } + + // +// default void forEach(Consumer consumer) { +// } +// +// default void forEachOrdered(Consumer consumer) { +// } +// +// default Object[] toArray() { +// return new Object[0]; +// } +// +// default A[] toArray(IntFunction intFunction) { +// return null; +// } +// +// default T reduce(T t, BinaryOperator binaryOperator) { +// return null; +// } +// +// default Optional reduce(BinaryOperator binaryOperator) { +// return Optional.empty(); +// } +// +// default U reduce(U u, BiFunction biFunction, BinaryOperator binaryOperator) { +// return null; +// } +// +// default R collect(Supplier supplier, BiConsumer biConsumer, BiConsumer biConsumer1) { +// return null; +// } +// +// default R collect(Collector collector) { +// return null; +// } +// +// default Optional min(Comparator comparator) { +// return Optional.empty(); +// } +// +// default Optional max(Comparator comparator) { +// return Optional.empty(); +// } +// +// default long count() { +// return 0; +// } +// +// default boolean anyMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// default boolean allMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// default boolean noneMatch(SerializablePredicate SerializablePredicate) { +// return false; +// } +// +// default Optional findFirst() { +// return Optional.empty(); +// } +// +// default Optional findAny() { +// return Optional.empty(); +// } + // +} diff --git a/src/main/java/jef/QueryableProxy.java b/src/main/java/jef/QueryableProxy.java new file mode 100644 index 0000000..0331529 --- /dev/null +++ b/src/main/java/jef/QueryableProxy.java @@ -0,0 +1,229 @@ +package jef; + +import jef.expressions.Expression; +import jef.serializable.SerializablePredicate; + +import java.io.Serializable; +import java.util.Iterator; +import java.util.Spliterator; + +public class QueryableProxy implements Queryable { + private final Queryable delegate; + + public QueryableProxy(Queryable delegate) { + this.delegate = delegate; + } + + @Override + public String getTableAlias() { + return delegate.getTableAlias(); + } + + @Override + public Expression getExpression() { + return delegate.getExpression(); + } + + @Override + public Iterator iterator() { + return delegate.iterator(); + } + + @Override + public Spliterator spliterator() { + return delegate.spliterator(); + } + + @Override + public boolean isParallel() { + return delegate.isParallel(); + } + + @Override + public Queryable sequential() { + return delegate.sequential(); + } + + @Override + public Queryable parallel() { + return delegate.parallel(); + } + + @Override + public Queryable unordered() { + return delegate.unordered(); + } + + @Override + public Queryable onClose(Runnable runnable) { + return delegate.onClose(runnable); + } + + @Override + public void close() { + delegate.close(); + } + + //stream + @Override + public Queryable filter(SerializablePredicate predicate) { + return delegate.filter(predicate); + } + +// @Override +// public Queryable map(Function function) { +// return delegate.map(function); +// } +// +// @Override +// public IntStream mapToInt(ToIntFunction toIntFunction) { +// return delegate.mapToInt(toIntFunction); +// } +// +// @Override +// public LongStream mapToLong(ToLongFunction toLongFunction) { +// return delegate.mapToLong(toLongFunction); +// } +// +// @Override +// public DoubleStream mapToDouble(ToDoubleFunction toDoubleFunction) { +// return delegate.mapToDouble(toDoubleFunction); +// } +// +// @Override +// public Queryable flatMap(Function> function) { +// return delegate.flatMap(function); +// } +// +// @Override +// public IntStream flatMapToInt(Function function) { +// return delegate.flatMapToInt(function); +// } +// +// @Override +// public LongStream flatMapToLong(Function function) { +// return delegate.flatMapToLong(function); +// } +// +// @Override +// public DoubleStream flatMapToDouble(Function function) { +// return delegate.flatMapToDouble(function); +// } +// +// @Override +// public Queryable distinct() { +// return delegate.distinct(); +// } +// +// @Override +// public Queryable sorted() { +// return delegate.sorted(); +// } +// +// @Override +// public Queryable sorted(Comparator comparator) { +// return delegate.sorted(comparator); +// } +// +// @Override +// public Queryable peek(Consumer consumer) { +// return delegate.peek(consumer); +// } +// +// @Override +// public Queryable limit(long l) { +// return delegate.limit(l); +// } +// +// @Override +// public Queryable skip(long l) { +// return delegate.skip(l); +// } + + // +// @Override +// public void forEach(Consumer consumer) { +// delegate.forEach(consumer); +// } +// +// @Override +// public void forEachOrdered(Consumer consumer) { +// delegate.forEachOrdered(consumer); +// } +// +// @Override +// public Object[] toArray() { +// return delegate.toArray(); +// } +// +// @Override +// public A[] toArray(IntFunction intFunction) { +// return delegate.toArray(intFunction); +// } +// +// @Override +// public T reduce(T t, BinaryOperator binaryOperator) { +// return delegate.reduce(t, binaryOperator); +// } +// +// @Override +// public Optional reduce(BinaryOperator binaryOperator) { +// return delegate.reduce(binaryOperator); +// } +// +// @Override +// public U reduce(U u, BiFunction biFunction, BinaryOperator binaryOperator) { +// return delegate.reduce(u, biFunction, binaryOperator); +// } +// +// @Override +// public R collect(Supplier supplier, BiConsumer biConsumer, BiConsumer biConsumer1) { +// return delegate.collect(supplier, biConsumer, biConsumer1); +// } +// +// @Override +// public R collect(Collector collector) { +// return delegate.collect(collector); +// } +// +// @Override +// public Optional min(Comparator comparator) { +// return delegate.min(comparator); +// } +// +// @Override +// public Optional max(Comparator comparator) { +// return delegate.max(comparator); +// } +// +// @Override +// public long count() { +// return delegate.count(); +// } +// +// @Override +// public boolean anyMatch(SerializablePredicate SerializablePredicate) { +// return delegate.anyMatch(SerializablePredicate); +// } +// +// @Override +// public boolean allMatch(SerializablePredicate SerializablePredicate) { +// return delegate.allMatch(SerializablePredicate); +// } +// +// @Override +// public boolean noneMatch(SerializablePredicate SerializablePredicate) { +// return delegate.noneMatch(SerializablePredicate); +// } +// +// @Override +// public Optional findFirst() { +// return delegate.findFirst(); +// } +// +// @Override +// public Optional findAny() { +// return delegate.findAny(); +// } + // +} diff --git a/src/main/java/jef/asm/AsmParseException.java b/src/main/java/jef/asm/AsmParseException.java new file mode 100644 index 0000000..4d7cc65 --- /dev/null +++ b/src/main/java/jef/asm/AsmParseException.java @@ -0,0 +1,22 @@ +package jef.asm; + +public class AsmParseException extends Exception { + public AsmParseException() { + } + + public AsmParseException(String message) { + super(message); + } + + public AsmParseException(String message, Throwable cause) { + super(message, cause); + } + + public AsmParseException(Throwable cause) { + super(cause); + } + + public AsmParseException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } +} diff --git a/src/main/java/jef/asm/FilterClassVisitor.java b/src/main/java/jef/asm/FilterClassVisitor.java new file mode 100644 index 0000000..a5b5845 --- /dev/null +++ b/src/main/java/jef/asm/FilterClassVisitor.java @@ -0,0 +1,32 @@ +package jef.asm; + +import jef.expressions.Expression; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.MethodVisitor; + +import java.util.function.Consumer; + +class FilterClassVisitor extends ClassVisitor { + private final int api; + private final String lambdaname; + private final Consumer queryConsumer; + private final Object[] args; + + protected FilterClassVisitor(int api, String lambdaname, Object[] args, Consumer queryConsumer) { + super(api); + this.api = api; + this.lambdaname = lambdaname; + this.args = args; + this.queryConsumer = queryConsumer; + } + + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) { + if (!name.equals(lambdaname)) { + return super.visitMethod(access, name, descriptor, signature, exceptions); + } + + System.out.println("found method name: " + name); + return new FilterMethodVisitor(api, descriptor, args, queryConsumer); + } +} diff --git a/src/main/java/jef/asm/FilterMethodVisitor.java b/src/main/java/jef/asm/FilterMethodVisitor.java new file mode 100644 index 0000000..1f77ab8 --- /dev/null +++ b/src/main/java/jef/asm/FilterMethodVisitor.java @@ -0,0 +1,445 @@ +package jef.asm; + +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.FieldExpression; +import jef.expressions.ParameterExpression; +import jef.expressions.SelectExpression; +import jef.expressions.TableExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; +import jef.expressions.WhereExpression; +import lombok.ToString; +import org.objectweb.asm.Attribute; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.Stack; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +class FilterMethodVisitor extends MethodVisitor { + private Stack varStack = new Stack<>(); + private final String[] parameterClasses; + private final Object[] args; + private final Consumer exprConsumer; + private Expression prediacteExpr; + + protected FilterMethodVisitor(int api, String descriptor, Object[] args, Consumer exprConsumer) { + super(api); + this.args = args; + this.exprConsumer = exprConsumer; + + //parameters + var types = Type.getMethodType(descriptor).getArgumentTypes(); + parameterClasses = new String[types.length]; + for (int i = 0; i < types.length; i++) { + parameterClasses[i] = types[i].getClassName(); + } + } + + @Override + public void visitLocalVariable(String name, String descriptor, String signature, Label start, Label end, int index) { + System.out.println("local var: " + name); + super.visitLocalVariable(name, descriptor, signature, start, end, index); + debugExpr(); + } + + @Override + public void visitFieldInsn(int opcode, String owner, String name, String descriptor) { + System.out.println("field insn: " + ops.getOrDefault(opcode, "" + opcode) + ", " + owner + ", " + name + ", " + descriptor); + if (opcode == Opcodes.GETFIELD) { + var v = varStack.pop(); + if (v instanceof ParameterExpression) { +// if (((Expression.ParameterExpression) v).isInput()) { +// varStack.push(new Expression.ConstantExpression(name)); +// } else { + System.out.println("womp womp " + v); + throw new RuntimeException("field insn: unsupported GETFIELD op"); +// } + } else if (v instanceof ConstantExpression) { + varStack.push(new FieldExpression(name)); + } else { + throw new RuntimeException("field insn: unsupported GETFIELD op"); + } + } else { + throw new RuntimeException("field insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + super.visitFieldInsn(opcode, owner, name, descriptor); + debugExpr(); + } + + @Override + public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) { + System.out.println("method insn: " + ops.getOrDefault(opcode, "" + opcode) + ", " + owner + ", " + name + ", " + descriptor); + if (opcode == Opcodes.INVOKESTATIC) { + //ignore boxed primitive types + var boxedPrimitiveClasses = Set.of("java/lang/Boolean", "java/lang/Integer", "java/lang/Long", "java/lang/Float", "java/lang/Double"); + if (boxedPrimitiveClasses.contains(owner)) { + //do nothing + } else { + //do something + throw new RuntimeException("method insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + } else if (opcode == Opcodes.INVOKEINTERFACE) { + try { + if (name.equals("contains") + && owner.startsWith("java/util/") + && Collection.class.isAssignableFrom(Class.forName(owner.replace("/", ".")))) { + var element = varStack.pop(); + var collection = varStack.pop(); +// System.out.println("element: " + element); +// System.out.println("collection: " + collection); + varStack.push(new BinaryExpression(element, collection, BinaryExpression.Operator.IN)); + } + } catch (ClassNotFoundException e) { + throw new RuntimeException("method insn: ", e); + } + } else { + throw new RuntimeException("method insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + super.visitMethodInsn(opcode, owner, name, descriptor, isInterface); + debugExpr(); + } + + @Override + public void visitVarInsn(int opcode, int varIndex) { + System.out.println("var insn: " + ops.getOrDefault(opcode, "" + opcode) + ", " + varIndex); + if (opcode == Opcodes.ALOAD) { + if (varIndex == parameterClasses.length - 1) { + varStack.push(new ConstantExpression("predicate param")); + } else { + varStack.push(new ParameterExpression(varIndex, args[varIndex], varIndex == parameterClasses.length - 1)); + } + } else { + throw new RuntimeException("var insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + super.visitVarInsn(opcode, varIndex); + debugExpr(); + } + + @Override + public void visitInsn(int opcode) { + System.out.println("insn: " + ops.getOrDefault(opcode, "" + opcode)); + if (opcode == Opcodes.ICONST_1) { + varStack.push(ConstantExpression.I1); + } else if (opcode == Opcodes.ICONST_0) { + varStack.push(ConstantExpression.I0); + } else if (opcode == Opcodes.IRETURN) { + //collapse conditions + for (int i = condStack.size() - 1; i >= 0; i--) { + condStack.get(i).e1 = varStack.pop(); + varStack = condStack.get(i).varStack; + evalCond(condStack.get(i)); + } +// condStack.clear(); + prediacteExpr = varStack.pop(); + } else { + throw new RuntimeException("insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + +// if (!mpgotoconds.isEmpty()) { +// var e1 = varStack.pop(); +// for (int i = mpgotoconds.size() - 1; i >= 0; i--) { +// var cond = mpgotoconds.get(i); +// cond.e1 = e1; +//// for (int j = 0; j < condStack.size() && condStack.get(j) != cond; j++) { +//// condStack.get(j).e1 = e1; +//// } +// condStack.remove(cond); +// varStack = cond.varStack; +// evalCond(cond); +// } +// mpgotoconds = new ArrayList<>(); +// } + + super.visitInsn(opcode); + debugExpr(); + } + + @ToString + class Cond { + final int opcode; + final Label condTarget; + Label gotoTarget; + final Stack varStack; + Expression e1; + Expression e2; + + public Cond(int opcode, Label condTarget, Stack varStack) { + this.opcode = opcode; + this.condTarget = condTarget; + this.varStack = varStack; + } + } + + private LinkedList condStack = new LinkedList<>(); + + @Override + public void visitJumpInsn(int opcode, Label label) { + System.out.println("jump insn: " + ops.getOrDefault(opcode, "" + opcode) + ", " + label); + + switch (opcode) { + case Opcodes.IFEQ: + case Opcodes.IFNE: + case Opcodes.IF_ICMPNE: + case Opcodes.IF_ICMPEQ: + case Opcodes.IF_ICMPGE: + case Opcodes.IF_ICMPLE: + case Opcodes.IF_ICMPGT: + case Opcodes.IF_ICMPLT: + handleLextLabel(); + condStack.add(new Cond(opcode, label, varStack)); + varStack = new Stack<>(); + System.out.println("jump new stack"); + break; + case Opcodes.GOTO: + handleLextLabel(); + System.out.println("goto new stack"); + condStack.getLast().gotoTarget = label; + condStack.getLast().e1 = varStack.pop(); + varStack = new Stack<>(); + break; + default: { + throw new RuntimeException("jump insn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + } + +// int IFEQ = 153; // visitJumpInsn +// int IFNE = 154; // - +// int IFLT = 155; // - +// int IFGE = 156; // - +// int IFGT = 157; // - +// int IFLE = 158; // - +// int IF_ICMPEQ = 159; // - +// int IF_ICMPNE = 160; // - +// int IF_ICMPLT = 161; // - +// int IF_ICMPGE = 162; // - +// int IF_ICMPGT = 163; // - +// int IF_ICMPLE = 164; // - +// int IF_ACMPEQ = 165; // - +// int IF_ACMPNE = 166; // - +// int GOTO = 167; // - +// int JSR = 168; // - +// int RET = 169; // visitVarInsn + + super.visitJumpInsn(opcode, label); + debugExpr(); + } + + List nextLabelConds = new ArrayList<>(); + + private void handleLextLabel() { + if (!nextLabelConds.isEmpty()) { + var e1 = varStack.peek(); + for (int i = 0; i < nextLabelConds.size(); i++) { +// for (int i = nextLabelConds.size() - 1; i >= 0; i--) { + var cond = nextLabelConds.get(i); + cond.e2 = e1; +// for (int j = 0; j < condStack.size() && condStack.get(j) != cond; j++) { +// condStack.get(j).e1 = e1; +// } +// condStack.remove(cond); +// varStack = cond.varStack; +// evalCond(cond); + } + nextLabelConds = new ArrayList<>(); + } + } + + @Override + public void visitLabel(Label label) { + System.out.println("label: " + label); + handleLextLabel(); + +// System.out.println("------>"); +// System.out.println("condstack: " + condStack); + List conds = condStack.stream().filter(e -> e.gotoTarget != null && e.gotoTarget.equals(label)).toList(); + nextLabelConds = condStack.stream().filter(e -> e.gotoTarget == null && e.condTarget.equals(label)).toList(); + + // List conds = condStack.stream().filter(e -> e.gotoTarget != null +// ? e.gotoTarget.equals(label) +// : e.condTarget.equals(label)//bad +// ).toList(); + if (!conds.isEmpty()) { + Expression e2 = varStack.pop(); + for (int i = conds.size() - 1; i >= 0; i--) { + var cond = conds.get(i); +// System.out.println("condstack pop " + cond); + cond.e2 = e2; + //also set e2 for all outer conds +// for (int j = 0; j < condStack.size() && condStack.get(j) != cond; j++) { +// condStack.get(j).e2 = e2; +// } + condStack.remove(cond); + varStack = cond.varStack; + evalCond(cond); + } + } +// System.out.println("<------"); + super.visitLabel(label); + debugExpr(); + } + + Label collapseAfterNextInstruction = null; + + private void evalCond(Cond cond) { + var right = varStack.pop(); + +// System.out.println("left: " + left); +// System.out.println("cond.e1: " + cond.e1); +// System.out.println("cond.e2: " + cond.e2); + boolean wrapInTernary = cond.e1 != ConstantExpression.I1 || cond.e2 != ConstantExpression.I0; +// boolean wrapInTernary = true; + Expression expr; +// 153, "IFEQ", +// 154, "IFNE", +// 155, "IFLT", +// 156, "IFGE", +// 157, "IFGT", +// 158, "IFLE", + switch (cond.opcode) { + case Opcodes.IFEQ: + expr = right; +// expr = new Expression.TernaryExpression(right, cond.e1, cond.e2); + break; + case Opcodes.IFNE: + expr = right; + +// expr = new Expression.TernaryExpression(right, cond.e1, cond.e2); +// if (expr instanceof Expression.TernaryExpression texpr +// && texpr.getWhenFalse() == Expression.ConstantExpression.I0) { +// expr = new Expression.BinaryExpression(texpr.getCond(), texpr.getWhenTrue(), "AND"); +// } +// wrapInTernary = false; + + expr = new UnaryExpression(expr, UnaryExpression.Operator.NOT); + break; + + default: { + var left = varStack.pop(); + switch (cond.opcode) { + case Opcodes.IF_ICMPEQ: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.NE); + break; + case Opcodes.IF_ICMPNE: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.EQ); + break; + case Opcodes.IF_ICMPLT: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.GE); + break; + case Opcodes.IF_ICMPGE: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.LT); + break; + case Opcodes.IF_ICMPGT: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.LE); + break; + case Opcodes.IF_ICMPLE: + expr = new BinaryExpression(left, right, BinaryExpression.Operator.GT); + break; + default: + throw new RuntimeException("jump insn: unsupported opcode " + ops.getOrDefault(cond.opcode, "" + cond.opcode)); + } + } + } + if (wrapInTernary) { + expr = new TernaryExpression(expr, cond.e1, cond.e2); + } + + varStack.push(expr); + debugExpr(); + } + + @Override + public void visitAttribute(Attribute attribute) { + System.out.println("attr: " + attribute); + super.visitAttribute(attribute); + debugExpr(); + } + + @Override + public void visitLdcInsn(Object value) { + System.out.println("ldc: " + value); + super.visitLdcInsn(value); + debugExpr(); + } + + @Override + public void visitEnd() { + System.out.println("end"); + + super.visitEnd(); + exprConsumer.accept(prediacteExpr); + } + + @Override + public void visitIntInsn(int opcode, int operand) { + System.out.println("intinsn: " + ops.getOrDefault(opcode, "" + opcode) + ", " + operand); + switch (opcode) { + case Opcodes.SIPUSH: + varStack.push(new ConstantExpression(operand)); + break; + default: + throw new RuntimeException("intinsn: unsupported opcode " + ops.getOrDefault(opcode, "" + opcode)); + } + super.visitIntInsn(opcode, operand); + debugExpr(); + } + + private Map ops = createOpsMap( + 3, "ICONST_0", + 4, "ICONST_1", + 17, "SIPUSH", + 25, "ALOAD", + + //jmp + 153, "IFEQ", + 154, "IFNE", + 155, "IFLT", + 156, "IFGE", + 157, "IFGT", + 158, "IFLE", + 159, "IF_ICMPEQ", + 160, "IF_ICMPNE", + 161, "IF_ICMPLT", + 162, "IF_ICMPGE", + 163, "IF_ICMPGT", + 164, "IF_ICMPLE", + 165, "IF_ACMPEQ", + 166, "IF_ACMPNE", + 167, "GOTO", + 168, "JSR", + 169, "RET", + + 172, "IRETURN", + + //field + 180, "GETFIELD", + 181, "PUTFIELD", + 182, "INVOKEVIRTUAL", + 183, "INVOKESPECIAL", + 184, "INVOKESTATIC", + 185, "INVOKEINTERFACE", + 186, "INVOKEDYNAMIC" + ); + + private static Map createOpsMap(Object... o) { + return IntStream.range(0, o.length / 2).boxed().collect(Collectors.toMap(i -> (int) o[i * 2], i -> (String) o[i * 2 + 1])); + } + + private void debugExpr() { + if (!varStack.isEmpty()) { + System.out.println("-------------------> " + new WhereExpression(new SelectExpression(List.of("*"), new TableExpression("dummy"), ""), varStack.peek())); + } + } +} diff --git a/src/main/java/jef/asm/PredicateParser.java b/src/main/java/jef/asm/PredicateParser.java new file mode 100644 index 0000000..8860a65 --- /dev/null +++ b/src/main/java/jef/asm/PredicateParser.java @@ -0,0 +1,60 @@ +package jef.asm; + +import jef.expressions.Expression; +import jef.serializable.SerializablePredicate; +import lombok.Getter; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.Opcodes; + +import java.io.InputStream; +import java.lang.invoke.SerializedLambda; +import java.util.stream.IntStream; + +@Getter +public class PredicateParser { + private final SerializablePredicate predicate; + + public PredicateParser(SerializablePredicate predicate) { + this.predicate = predicate; + } + + public Expression parse() throws AsmParseException { + try { + return parseExpression(); + } catch (Exception e) { + throw new AsmParseException("PredicateParser: failed to parse expression", e); + } + } + + private Expression parseExpression() throws Exception { + var cls = predicate.getClass(); + var loader = cls.getClassLoader(); + InputStream is; +// System.out.println(cls); +// System.out.println(cls.getName()); + if (cls.getName().contains("$$Lambda$")) { +// System.out.println(cls.getName().split("\\$\\$")[0].replace(".", "/") + ".class"); + is = loader.getResourceAsStream(cls.getName().split("\\$\\$")[0].replace(".", "/") + ".class"); + } else { +// System.out.println(cls.getName().replace(".", "/") + ".class"); + is = loader.getResourceAsStream(cls.getName().replace(".", "/") + ".class"); + } + var x = cls.getDeclaredMethod("writeReplace"); +// System.out.println(x); + x.setAccessible(true); + var serlambda = (SerializedLambda) x.invoke(predicate); + Object[] args = IntStream.range(0, serlambda.getCapturedArgCount()).mapToObj(serlambda::getCapturedArg).toArray(); +// System.out.println(serlambda); + + var lambdaname = serlambda.getImplMethodName(); +// System.out.println(lambdaname); + + var expr = new Expression[1]; + + var cr = new ClassReader(is); + var visiter = new FilterClassVisitor(Opcodes.ASM9, lambdaname, args, e -> expr[0] = e); + cr.accept(visiter, 0); + + return expr[0]; + } +} diff --git a/src/main/java/jef/expressions/AndExpression.java b/src/main/java/jef/expressions/AndExpression.java new file mode 100644 index 0000000..4d7e1bb --- /dev/null +++ b/src/main/java/jef/expressions/AndExpression.java @@ -0,0 +1,33 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.List; +import java.util.stream.Collectors; + +@Getter +@AllArgsConstructor +public class AndExpression implements Expression { + private final List exprs; + + public AndExpression(Expression... exprs) { + this.exprs = List.of(exprs); + } + + @Override + public Type getType() { + return Type.AND; + } + + @Override + public String toString() { + return exprs.stream().map(e -> { + if (e instanceof OrExpression) { + return "(" + e + ")"; + } else { + return e.toString(); + } + }).collect(Collectors.joining(" AND ")); + } +} diff --git a/src/main/java/jef/expressions/BinaryExpression.java b/src/main/java/jef/expressions/BinaryExpression.java new file mode 100644 index 0000000..eb6bf6a --- /dev/null +++ b/src/main/java/jef/expressions/BinaryExpression.java @@ -0,0 +1,60 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Map; + +@Getter +@AllArgsConstructor +public class BinaryExpression implements Expression { + private final Expression left; + private final Expression right; + private final Operator operator; + + @Override + public Type getType() { + return Type.BINARY; + } + + @Override + public String toString() { + return left + " " + operator + " " + right; + } + + @AllArgsConstructor + public enum Operator { + EQ("="), + NE("<>"), + LT("<"), + LE("<="), + GT(">"), + GE(">="), + // OR("OR"), +// AND("AND"), + IN("IN"), + ; + private final String string; + + @Override + public String toString() { + return string; + } + + private static final Map INVERSION = Map.of( + EQ, NE, + NE, EQ, + LT, GE, + GE, LT, + LE, GT, + GT, LE); + + public Operator invert() { + return INVERSION.get(this); + } + + public boolean isInvertible() { + return INVERSION.containsKey(this); + } + } +} diff --git a/src/main/java/jef/expressions/ConstantExpression.java b/src/main/java/jef/expressions/ConstantExpression.java new file mode 100644 index 0000000..1343b2a --- /dev/null +++ b/src/main/java/jef/expressions/ConstantExpression.java @@ -0,0 +1,23 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class ConstantExpression implements Expression { + public static final jef.expressions.ConstantExpression I0 = new jef.expressions.ConstantExpression(0); + public static final jef.expressions.ConstantExpression I1 = new jef.expressions.ConstantExpression(1); + + protected final Object value; + + @Override + public Type getType() { + return Type.CONSTANT; + } + + @Override + public String toString() { + return value.toString(); + } +} diff --git a/src/main/java/jef/expressions/Expression.java b/src/main/java/jef/expressions/Expression.java new file mode 100644 index 0000000..cc9063c --- /dev/null +++ b/src/main/java/jef/expressions/Expression.java @@ -0,0 +1,19 @@ +package jef.expressions; + +public interface Expression { + Type getType(); + + public enum Type { + AND, + BINARY, + CONSTANT, + FIELD, + OR, + PARAMETER, + SELECT, + TABLE, + TERNARY, + UNARY, + WHERE, + } +} diff --git a/src/main/java/jef/expressions/FieldExpression.java b/src/main/java/jef/expressions/FieldExpression.java new file mode 100644 index 0000000..edd0172 --- /dev/null +++ b/src/main/java/jef/expressions/FieldExpression.java @@ -0,0 +1,23 @@ +package jef.expressions; + +import lombok.Getter; + +@Getter +public class FieldExpression extends ConstantExpression implements Expression { + private final String name; + + public FieldExpression(String name) { + super(name); + this.name = name; + } + + @Override + public Type getType() { + return Type.FIELD; + } + + @Override + public String toString() { + return name; + } +} diff --git a/src/main/java/jef/expressions/OrExpression.java b/src/main/java/jef/expressions/OrExpression.java new file mode 100644 index 0000000..b5cc96d --- /dev/null +++ b/src/main/java/jef/expressions/OrExpression.java @@ -0,0 +1,27 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.List; +import java.util.stream.Collectors; + +@Getter +@AllArgsConstructor +public class OrExpression implements Expression { + private final List exprs; + + public OrExpression(Expression... exprs) { + this.exprs = List.of(exprs); + } + + @Override + public Type getType() { + return Type.OR; + } + + @Override + public String toString() { + return exprs.stream().map(Expression::toString).collect(Collectors.joining(" OR ")); + } +} diff --git a/src/main/java/jef/expressions/ParameterExpression.java b/src/main/java/jef/expressions/ParameterExpression.java new file mode 100644 index 0000000..fcc12ae --- /dev/null +++ b/src/main/java/jef/expressions/ParameterExpression.java @@ -0,0 +1,35 @@ +package jef.expressions; + +import lombok.Getter; + +import java.util.Collection; +import java.util.stream.Collectors; + +@Getter +public class ParameterExpression extends ConstantExpression implements Expression { + private final int index; + private final boolean isInput; + + public ParameterExpression(int index, Object value, boolean isInput) { + super(value); + this.index = index; + this.isInput = isInput; + } + + @Override + public Type getType() { + return Type.PARAMETER; + } + + @Override + public String toString() { + if (isInput) { + return "param #" + index; + } else if (this.value == null) { + return "null"; + } else if (this.value instanceof Collection) { + return "(" + ((Collection) this.value).stream().map(Object::toString).collect(Collectors.joining(",")) + ")"; + } + return value.toString(); + } +} diff --git a/src/main/java/jef/expressions/SelectExpression.java b/src/main/java/jef/expressions/SelectExpression.java new file mode 100644 index 0000000..c312ae8 --- /dev/null +++ b/src/main/java/jef/expressions/SelectExpression.java @@ -0,0 +1,27 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.List; +import java.util.stream.Collectors; + +@Getter +@AllArgsConstructor +public class SelectExpression implements Expression { + private final List fields; + private final Expression from; + private final String fromAlias; + + @Override + public Type getType() { + return Type.SELECT; + } + + @Override + public String toString() { + return "SELECT " + fields.stream().map(e -> e.equals("*") ? e : "`" + e + "`").collect(Collectors.joining(", ")) + " FROM " + + (!(from instanceof TableExpression) ? "(" + from + ")" : from) + + ((fromAlias == null || fromAlias.isBlank()) ? "" : " " + fromAlias); + } +} diff --git a/src/main/java/jef/expressions/TableExpression.java b/src/main/java/jef/expressions/TableExpression.java new file mode 100644 index 0000000..1080c09 --- /dev/null +++ b/src/main/java/jef/expressions/TableExpression.java @@ -0,0 +1,28 @@ +package jef.expressions; + +import lombok.Getter; + +@Getter +public class TableExpression extends ConstantExpression { + private final String name; + + public TableExpression(String name) { + super(name); + this.name = name; + } + + @Override + public Type getType() { + return Type.TABLE; + } + + @Override + public String getValue() { + return (String) super.getValue(); + } + + @Override + public String toString() { + return "`" + super.toString() + "`"; + } +} diff --git a/src/main/java/jef/expressions/TernaryExpression.java b/src/main/java/jef/expressions/TernaryExpression.java new file mode 100644 index 0000000..7dc8b92 --- /dev/null +++ b/src/main/java/jef/expressions/TernaryExpression.java @@ -0,0 +1,24 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class TernaryExpression implements Expression { + private final Expression cond; + private final Expression whenTrue; + private final Expression whenFalse; + + @Override + public Type getType() { + return Type.TERNARY; + } + + @Override + public String toString() { + return cond + + " ? " + (!(whenTrue instanceof ConstantExpression) ? "(" + whenTrue + ")" : whenTrue) + + " : " + (!(whenFalse instanceof ConstantExpression) ? "(" + whenFalse + ")" : whenFalse); + } +} diff --git a/src/main/java/jef/expressions/UnaryExpression.java b/src/main/java/jef/expressions/UnaryExpression.java new file mode 100644 index 0000000..d96543c --- /dev/null +++ b/src/main/java/jef/expressions/UnaryExpression.java @@ -0,0 +1,35 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class UnaryExpression implements Expression { + private final Expression expr; + private final Operator operator; + + @Override + public Type getType() { + return Type.UNARY; + } + + @Override + public String toString() { + return operator + " (" + expr + ")"; + } + + @AllArgsConstructor + public enum Operator { + NOT("NOT"), +// NEG("-"), +// POS("+"), + ; + private final String string; + + @Override + public String toString() { + return string; + } + } +} diff --git a/src/main/java/jef/expressions/WhereExpression.java b/src/main/java/jef/expressions/WhereExpression.java new file mode 100644 index 0000000..7dc8e63 --- /dev/null +++ b/src/main/java/jef/expressions/WhereExpression.java @@ -0,0 +1,21 @@ +package jef.expressions; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +@Getter +@AllArgsConstructor +public class WhereExpression implements Expression { + private Expression queryable; + private Expression where; + + @Override + public Type getType() { + return Type.WHERE; + } + + @Override + public String toString() { + return queryable + " WHERE " + where; + } +} diff --git a/src/main/java/jef/expressions/modifier/ExpressionModifier.java b/src/main/java/jef/expressions/modifier/ExpressionModifier.java new file mode 100644 index 0000000..49a991f --- /dev/null +++ b/src/main/java/jef/expressions/modifier/ExpressionModifier.java @@ -0,0 +1,89 @@ +package jef.expressions.modifier; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.FieldExpression; +import jef.expressions.OrExpression; +import jef.expressions.ParameterExpression; +import jef.expressions.SelectExpression; +import jef.expressions.TableExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; +import jef.expressions.WhereExpression; + +import java.util.ArrayList; + + +public abstract class ExpressionModifier { + + public Expression modify(Expression expr) { + return switch (expr.getType()) { + case AND -> modifyAnd((AndExpression) expr); + case BINARY -> modifyBinary((BinaryExpression) expr); + case CONSTANT -> modifyConstant((ConstantExpression) expr); + case FIELD -> modifyField((FieldExpression) expr); + case OR -> modifyOr((OrExpression) expr); + case PARAMETER -> modifyParameter((ParameterExpression) expr); + case SELECT -> modifySelect((SelectExpression) expr); + case TABLE -> modifyTable((TableExpression) expr); + case TERNARY -> modifyTernary((TernaryExpression) expr); + case UNARY -> modifyUnary((UnaryExpression) expr); + case WHERE -> modifyWhere((WhereExpression) expr); + default -> throw new IllegalStateException(); + }; + } + + public Expression modifyAnd(AndExpression expr) { + var exprs = new ArrayList(expr.getExprs().size()); + for (Expression e : expr.getExprs()) { + exprs.add(modify(e)); + } + return new AndExpression(exprs); + } + + public Expression modifyBinary(BinaryExpression expr) { + return new BinaryExpression(modify(expr.getLeft()), modify(expr.getRight()), expr.getOperator()); + } + + public Expression modifyConstant(ConstantExpression expr) { + return expr; + } + + public Expression modifyField(FieldExpression expr) { + return expr; + } + + public Expression modifyOr(OrExpression expr) { + var exprs = new ArrayList(expr.getExprs().size()); + for (Expression e : expr.getExprs()) { + exprs.add(modify(e)); + } + return new OrExpression(exprs); + } + + public Expression modifyParameter(ParameterExpression expr) { + return expr; + } + + public Expression modifySelect(SelectExpression expr) { + return new SelectExpression(expr.getFields(), modify(expr.getFrom()), expr.getFromAlias()); + } + + public Expression modifyTable(TableExpression expr) { + return expr; + } + + public Expression modifyTernary(TernaryExpression expr) { + return new TernaryExpression(modify(expr.getCond()), modify(expr.getWhenTrue()), modify(expr.getWhenFalse())); + } + + public Expression modifyUnary(UnaryExpression expr) { + return new UnaryExpression(modify(expr.getExpr()), expr.getOperator()); + } + + public Expression modifyWhere(WhereExpression expr) { + return new WhereExpression(modify(expr.getQueryable()), modify(expr.getWhere())); + } +} diff --git a/src/main/java/jef/expressions/modifier/ExpressionOptimizer.java b/src/main/java/jef/expressions/modifier/ExpressionOptimizer.java new file mode 100644 index 0000000..b7228ec --- /dev/null +++ b/src/main/java/jef/expressions/modifier/ExpressionOptimizer.java @@ -0,0 +1,102 @@ +package jef.expressions.modifier; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.OrExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; + +import java.util.ArrayList; + +public class ExpressionOptimizer extends ExpressionModifier { + @Override + public Expression modifyAnd(AndExpression expr) { + var ands = new ArrayList(expr.getExprs().size() * 2); + + //squash ands + for (Expression e : expr.getExprs()) { + if (e.getType() == Expression.Type.AND) { + ands.addAll(((AndExpression) e).getExprs()); + } else { + ands.add(e); + } + } + ands.replaceAll(this::modify); + + // x && false -> false + for (Expression e : ands) { + if (e == ConstantExpression.I0) { + return ConstantExpression.I0; + } + } + while (ands.remove(ConstantExpression.I1)) ; + + return new AndExpression(ands); + } + + @Override + public Expression modifyOr(OrExpression expr) { + var ors = new ArrayList(expr.getExprs().size() * 2); + + //squash ors + for (Expression e : expr.getExprs()) { + if (e.getType() == Expression.Type.OR) { + ors.addAll(((OrExpression) e).getExprs()); + } else { + ors.add(e); + } + } + ors.replaceAll(this::modify); + + // x || true -> true + for (Expression e : ors) { + if (e == ConstantExpression.I1) { + return ConstantExpression.I1; + } + } + while (ors.remove(ConstantExpression.I0)) ; + + return new OrExpression(ors); + } + + @Override + public Expression modifyTernary(TernaryExpression expr) { + if (expr.getWhenFalse() == ConstantExpression.I1 && expr.getWhenFalse() == ConstantExpression.I0) { + //x ? 1 : 0 -> x + return modify(expr.getCond()); + } else if (expr.getWhenFalse() == ConstantExpression.I0 && expr.getWhenFalse() == ConstantExpression.I1) { + //x ? 0 : 1 -> !x + return modify(new UnaryExpression(expr.getCond(), UnaryExpression.Operator.NOT)); + } else if (expr.getWhenFalse() == ConstantExpression.I0) { + //x ? y : 0 -> x && y + return modify(new AndExpression(expr.getCond(), expr.getWhenTrue())); + } else if (expr.getWhenFalse() == ConstantExpression.I1) { + //x ? y : 1 -> !x or y + return modify(new OrExpression(new UnaryExpression(expr.getCond(), UnaryExpression.Operator.NOT), expr.getWhenTrue())); + } else if (expr.getWhenTrue() instanceof TernaryExpression t && expr.getWhenFalse() == t.getWhenFalse()) { + // x ? (y ? z : u) : u -> (x && y) ? z : u + return new TernaryExpression(new AndExpression(expr.getCond(), t.getCond()), t.getWhenTrue(), t.getWhenFalse()); + } else { + return super.modifyTernary(expr); + } + } + + @Override + public Expression modifyUnary(UnaryExpression expr) { + if (expr.getExpr() instanceof UnaryExpression u + && expr.getOperator() == u.getOperator() + && expr.getOperator() == UnaryExpression.Operator.NOT) { + //!!x -> x + return modify(u.getExpr()); + } else if (expr.getExpr() instanceof BinaryExpression b + && expr.getOperator() == UnaryExpression.Operator.NOT + && b.getOperator().isInvertible()) { + //!(a < b) -> a >= b + return new BinaryExpression(b.getLeft(), b.getRight(), b.getOperator().invert()); + } else { + return super.modifyUnary(expr); + } + } +} diff --git a/src/main/java/jef/expressions/modifier/ExpressionOptimizerBottomUp.java b/src/main/java/jef/expressions/modifier/ExpressionOptimizerBottomUp.java new file mode 100644 index 0000000..05df653 --- /dev/null +++ b/src/main/java/jef/expressions/modifier/ExpressionOptimizerBottomUp.java @@ -0,0 +1,108 @@ +package jef.expressions.modifier; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.OrExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; + +import java.util.ArrayList; + +public class ExpressionOptimizerBottomUp extends ExpressionModifier { + @Override + public Expression modifyAnd(AndExpression expr) { + var andsOpt = expr.getExprs().stream().map(this::modify).toList(); + var ands = new ArrayList(expr.getExprs().size() * 2); + + //squash ands + for (Expression e : andsOpt) { + if (e.getType() == Expression.Type.AND) { + ands.addAll(((AndExpression) e).getExprs()); + } else { + ands.add(e); + } + } + ands.replaceAll(this::modify); + + // x && false -> false + for (Expression e : ands) { + if (e == ConstantExpression.I0) { + return ConstantExpression.I0; + } + } + while (ands.remove(ConstantExpression.I1)) ; + + return new AndExpression(ands); + } + + @Override + public Expression modifyOr(OrExpression expr) { + var orsOpt = expr.getExprs().stream().map(this::modify).toList(); + var ors = new ArrayList(expr.getExprs().size() * 2); + + //squash ors + for (Expression e : orsOpt) { + if (e.getType() == Expression.Type.OR) { + ors.addAll(((OrExpression) e).getExprs()); + } else { + ors.add(e); + } + } + ors.replaceAll(this::modify); + + // x || true -> true + for (Expression e : ors) { + if (e == ConstantExpression.I1) { + return ConstantExpression.I1; + } + } + while (ors.remove(ConstantExpression.I0)) ; + + return new OrExpression(ors); + } + + @Override + public Expression modifyTernary(TernaryExpression expr) { + var cond = modify(expr.getCond()); + var whenTrue = modify(expr.getWhenTrue()); + var whenFalse = modify(expr.getWhenFalse()); + if (whenTrue == ConstantExpression.I1 && whenFalse == ConstantExpression.I0) { + //x ? 1 : 0 -> x + return cond; + } else if (whenTrue == ConstantExpression.I0 && whenFalse == ConstantExpression.I1) { + //x ? 0 : 1 -> !x + return modify(new UnaryExpression(cond, UnaryExpression.Operator.NOT)); + } else if (whenFalse == ConstantExpression.I0) { + //x ? y : 0 -> x && y + return modify(new AndExpression(cond, whenTrue)); + } else if (whenFalse == ConstantExpression.I1) { + //x ? y : 1 -> !x or y + return modify(new OrExpression(new UnaryExpression(cond, UnaryExpression.Operator.NOT), whenTrue)); + } else if (whenTrue instanceof TernaryExpression t && whenFalse == t.getWhenFalse()) { + // x ? (y ? z : u) : u -> (x && y) ? z : u + return new TernaryExpression(new AndExpression(cond, t.getCond()), t.getWhenTrue(), t.getWhenFalse()); + } else { + return super.modifyTernary(expr); + } + } + + @Override + public Expression modifyUnary(UnaryExpression expr) { + var inner = modify(expr.getExpr()); + if (inner instanceof UnaryExpression u + && expr.getOperator() == u.getOperator() + && expr.getOperator() == UnaryExpression.Operator.NOT) { + //!!x -> x + return modify(u.getExpr()); + } else if (inner instanceof BinaryExpression b + && expr.getOperator() == UnaryExpression.Operator.NOT + && b.getOperator().isInvertible()) { + //!(a < b) -> a >= b + return new BinaryExpression(b.getLeft(), b.getRight(), b.getOperator().invert()); + } else { + return super.modifyUnary(expr); + } + } +} diff --git a/src/main/java/jef/expressions/modifier/TableAliasInjector.java b/src/main/java/jef/expressions/modifier/TableAliasInjector.java new file mode 100644 index 0000000..90da0ae --- /dev/null +++ b/src/main/java/jef/expressions/modifier/TableAliasInjector.java @@ -0,0 +1,19 @@ +package jef.expressions.modifier; + +import jef.expressions.Expression; +import jef.expressions.FieldExpression; + +public class TableAliasInjector extends ExpressionModifier { + private final String tableAlias; + private final String prefix; + + public TableAliasInjector(String tableAlias) { + this.tableAlias = tableAlias; + this.prefix = (tableAlias == null || tableAlias.isBlank()) ? "" : tableAlias + "."; + } + + @Override + public Expression modifyField(FieldExpression expr) { + return new FieldExpression(prefix + expr.getName()); + } +} diff --git a/src/main/java/jef/expressions/modifier/TernaryRewriter.java b/src/main/java/jef/expressions/modifier/TernaryRewriter.java new file mode 100644 index 0000000..4e3ddd1 --- /dev/null +++ b/src/main/java/jef/expressions/modifier/TernaryRewriter.java @@ -0,0 +1,20 @@ +package jef.expressions.modifier; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.OrExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; + +import java.util.ArrayList; + +public class TernaryRewriter extends ExpressionModifier { + @Override + public Expression modifyTernary(TernaryExpression expr) { + return new OrExpression(new AndExpression(expr.getCond(), expr.getWhenTrue()), + new AndExpression(new UnaryExpression(expr.getCond(), UnaryExpression.Operator.NOT), expr.getWhenFalse())); +// return new OrExpression(new AndExpression(expr.getCond(), expr.getWhenTrue()), expr.getWhenFalse()); + } +} diff --git a/src/main/java/jef/expressions/visitors/DebugExpressionVisitor.java b/src/main/java/jef/expressions/visitors/DebugExpressionVisitor.java new file mode 100644 index 0000000..561f412 --- /dev/null +++ b/src/main/java/jef/expressions/visitors/DebugExpressionVisitor.java @@ -0,0 +1,128 @@ +package jef.expressions.visitors; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.FieldExpression; +import jef.expressions.OrExpression; +import jef.expressions.ParameterExpression; +import jef.expressions.SelectExpression; +import jef.expressions.TableExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; +import jef.expressions.WhereExpression; + +import java.util.Collection; +import java.util.stream.Collectors; + +public class DebugExpressionVisitor extends ExpressionVisitor { + private int indent = 0; + + @Override + public void visitAnd(AndExpression expr) { + for (int i = 0; i < expr.getExprs().size(); i++) { + indent++; + visit(expr.getExprs().get(i)); + indent--; + if (i + 1 < expr.getExprs().size()) + System.out.println(i() + "AND"); + } + } + + @Override + public void visitBinary(BinaryExpression expr) { + indent++; + visit(expr.getLeft()); + indent--; + System.out.println(i() + expr.getOperator()); + indent++; + visit(expr.getRight()); + indent--; + } + + @Override + public void visitConstant(ConstantExpression expr) { + System.out.println(i() + expr.getValue()); + } + + @Override + public void visitField(FieldExpression expr) { + System.out.println(i() + expr.getName()); + } + + @Override + public void visitOr(OrExpression expr) { + for (int i = 0; i < expr.getExprs().size(); i++) { + indent++; + visit(expr.getExprs().get(i)); + indent--; + if (i + 1 < expr.getExprs().size()) + System.out.println(i() + "OR"); + } + } + + @Override + public void visitParameter(ParameterExpression expr) { + if (expr.getValue() instanceof Collection c) { + System.out.println(i() + c.stream().map(String::valueOf).collect(Collectors.joining(","))); + } else { + System.out.println(i() + expr.getValue()); + } + } + + @Override + public void visitSelect(SelectExpression expr) { + var table = expr.getFrom() instanceof TableExpression; + var tableName = table ? ((TableExpression) expr.getFrom()).getName() : null; + var tableAlias = (expr.getFromAlias() == null || expr.getFromAlias().isBlank()) ? "" : " " + expr.getFromAlias(); + System.out.println(i() + "SELECT " + expr.getFields().stream().collect(Collectors.joining(", ")) + + " FROM" + (!table ? " (" : tableName + " " + tableAlias)); + if (!table) { + indent++; + visit(expr.getFrom()); + indent--; + System.out.println(i() + ")" + tableAlias); + } + } + + @Override + public void visitTable(TableExpression expr) { + System.out.println(i() + expr.getValue()); + } + + @Override + public void visitTernary(TernaryExpression expr) { + indent++; + visit(expr.getCond()); + indent--; + System.out.println(i() + "?"); + indent++; + visit(expr.getWhenTrue()); + indent--; + System.out.println(i() + ":"); + indent++; + visit(expr.getWhenFalse()); + indent--; + } + + @Override + public void visitUnary(UnaryExpression expr) { + System.out.println(i() + expr.getOperator()); + indent++; + visit(expr.getExpr()); + indent--; + } + + @Override + public void visitWhere(WhereExpression expr) { + visit(expr.getQueryable()); + System.out.println(i() + "WHERE"); + indent++; + visit(expr.getWhere()); + indent--; + } + + private String i() { + return " ".repeat(indent); + } +} diff --git a/src/main/java/jef/expressions/visitors/ExpressionVisitor.java b/src/main/java/jef/expressions/visitors/ExpressionVisitor.java new file mode 100644 index 0000000..7376640 --- /dev/null +++ b/src/main/java/jef/expressions/visitors/ExpressionVisitor.java @@ -0,0 +1,81 @@ +package jef.expressions.visitors; + +import jef.expressions.AndExpression; +import jef.expressions.BinaryExpression; +import jef.expressions.ConstantExpression; +import jef.expressions.Expression; +import jef.expressions.FieldExpression; +import jef.expressions.OrExpression; +import jef.expressions.ParameterExpression; +import jef.expressions.SelectExpression; +import jef.expressions.TableExpression; +import jef.expressions.TernaryExpression; +import jef.expressions.UnaryExpression; +import jef.expressions.WhereExpression; + +public abstract class ExpressionVisitor { + public void visit(Expression expr) { + switch (expr.getType()) { + case AND -> visitAnd((AndExpression) expr); + case BINARY -> visitBinary((BinaryExpression) expr); + case CONSTANT -> visitConstant((ConstantExpression) expr); + case FIELD -> visitField((FieldExpression) expr); + case OR -> visitOr((OrExpression) expr); + case PARAMETER -> visitParameter((ParameterExpression) expr); + case SELECT -> visitSelect((SelectExpression) expr); + case TABLE -> visitTable((TableExpression) expr); + case TERNARY -> visitTernary((TernaryExpression) expr); + case UNARY -> visitUnary((UnaryExpression) expr); + case WHERE -> visitWhere((WhereExpression) expr); + default -> throw new IllegalStateException(); + } + } + + public void visitAnd(AndExpression expr) { + for (Expression e : expr.getExprs()) { + visit(e); + } + } + + public void visitBinary(BinaryExpression expr) { + visit(expr.getLeft()); + visit(expr.getRight()); + } + + public void visitConstant(ConstantExpression expr) { + } + + public void visitField(FieldExpression expr) { + } + + public void visitOr(OrExpression expr) { + for (Expression e : expr.getExprs()) { + visit(e); + } + } + + public void visitParameter(ParameterExpression expr) { + } + + public void visitSelect(SelectExpression expr) { + visit(expr.getFrom()); + } + + public void visitTable(TableExpression expr) { + } + + public void visitTernary(TernaryExpression expr) { + visit(expr.getCond()); + visit(expr.getWhenTrue()); + visit(expr.getWhenFalse()); + } + + public void visitUnary(UnaryExpression expr) { + visit(expr.getExpr()); + } + + public void visitWhere(WhereExpression expr) { + visit(expr.getQueryable()); + visit(expr.getWhere()); + } +} diff --git a/src/main/java/jef/operations/FilterOp.java b/src/main/java/jef/operations/FilterOp.java new file mode 100644 index 0000000..afc8632 --- /dev/null +++ b/src/main/java/jef/operations/FilterOp.java @@ -0,0 +1,58 @@ +package jef.operations; + +import jef.Queryable; +import jef.asm.AsmParseException; +import jef.asm.PredicateParser; +import jef.expressions.Expression; +import jef.expressions.SelectExpression; +import jef.expressions.WhereExpression; +import jef.expressions.modifier.ExpressionOptimizer; +import jef.expressions.modifier.ExpressionOptimizerBottomUp; +import jef.expressions.modifier.TableAliasInjector; +import jef.expressions.modifier.TernaryRewriter; +import jef.serializable.SerializablePredicate; + +import java.io.Serializable; +import java.util.List; +import java.util.function.Predicate; + +public class FilterOp implements Queryable, Operation { + + private final Queryable queryable; + private final Predicate predicate; + private final Expression predicateExpr; + + public FilterOp(Queryable queryable, SerializablePredicate predicate) { + this.queryable = queryable; + this.predicate = predicate; + var parser = new PredicateParser(predicate); + Expression expr; + try { + expr = parser.parse(); + } catch (AsmParseException e) { + throw new RuntimeException(e); + } + System.out.println(expr); + expr = new TernaryRewriter().modify(expr); + System.out.println(expr); + expr = new ExpressionOptimizer().modify(expr); +// expr = new ExpressionOptimizerBottomUp().modify(expr); + expr = new TableAliasInjector(getTableAlias()).modify(expr); + this.predicateExpr = expr; + } + + @Override + public String getTableAlias() { + return String.valueOf((char) (queryable.getTableAlias().charAt(0) + (char) 1)); + } + + @Override + public Expression getExpression() { + return new WhereExpression(new SelectExpression(List.of("*"), queryable.getExpression(), getTableAlias()), predicateExpr); + } + + @Override + public String toString() { + return getExpression().toString(); + } +} diff --git a/src/main/java/jef/operations/Operation.java b/src/main/java/jef/operations/Operation.java new file mode 100644 index 0000000..35184ee --- /dev/null +++ b/src/main/java/jef/operations/Operation.java @@ -0,0 +1,9 @@ +package jef.operations; + +import jef.Queryable; + +import java.io.Serializable; + +public interface Operation extends Queryable { + +} diff --git a/src/main/java/jef/serializable/SerializableFunction.java b/src/main/java/jef/serializable/SerializableFunction.java new file mode 100644 index 0000000..997c502 --- /dev/null +++ b/src/main/java/jef/serializable/SerializableFunction.java @@ -0,0 +1,7 @@ +package jef.serializable; + +import java.io.Serializable; +import java.util.function.Function; + +public interface SerializableFunction extends Function, Serializable { +} diff --git a/src/main/java/jef/serializable/SerializablePredicate.java b/src/main/java/jef/serializable/SerializablePredicate.java new file mode 100644 index 0000000..f4755fe --- /dev/null +++ b/src/main/java/jef/serializable/SerializablePredicate.java @@ -0,0 +1,7 @@ +package jef.serializable; + +import java.io.Serializable; +import java.util.function.Predicate; + +public interface SerializablePredicate extends Predicate, Serializable { +} diff --git a/src/test/java/jef/operations/FilterOpTest.java b/src/test/java/jef/operations/FilterOpTest.java new file mode 100644 index 0000000..0080443 --- /dev/null +++ b/src/test/java/jef/operations/FilterOpTest.java @@ -0,0 +1,144 @@ +package jef.operations; + +import jef.DBSet; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.Serializable; +import java.util.List; + +public class FilterOpTest { + + @Test + public void testTrue() { +// String act; +// act = new DBSet("table1") +// .filter(e -> true) +// .toString(); +// Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE 1", act); + } + + @Test + public void testCompareWithEntityMember() { + String act; + act = new DBSet("table1") + .filter(e -> e.b == 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b = 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b != 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b <> 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b < 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b < 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b > 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b > 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b <= 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b <= 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b >= 1) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b >= 1", act); + + + act = new DBSet("table1") + .filter(e -> e.b == 1337) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b = 1337", act); + } + + @Test + public void testContainsWithEntityMember() { + var s = List.of(1, 3); + String act; + act = new DBSet("table1") + .filter(e -> s.contains(e.b)) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b IN (1,3)", act); + } + + @Test + public void testMultipleFilter() { + String act; + var s = List.of(1, 3); + act = new DBSet("table1") + .filter(e -> s.contains(e.b)) + .filter(e -> e.b == 1337) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b IN (1,3)) b WHERE b.b = 1337", act); + } + + @Test + public void testComplexExpression() { + String act; + var s = List.of(1, 3); + act = new DBSet("table1") + .filter(e -> s.contains(e.b) && e.b == 1337) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b IN (1,3) AND a.b = 1337", act); + + + act = new DBSet("table1") + .filter(e -> s.contains(e.b) && e.b == 1337 && e.b == 420) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b IN (1,3) AND a.b = 1337 AND a.b = 420", act); + + + act = new DBSet("table1") + .filter(e -> e.b == 1337 || e.b != 420 || s.contains(e.b)) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b = 1337 OR a.b <> 420 OR a.b IN (1,3)", act); + + + act = new DBSet("table1") + .filter(e -> e.b == 1337 || e.b == 420 || s.contains(e.b)) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b = 1337 OR a.b = 420 OR a.b IN (1,3)", act); + + act = new DBSet("table1") + .filter(e -> e.b == 1337 || s.contains(e.b) || e.b == 420) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b = 1337 OR a.b IN (1,3) OR a.b = 420", act); + } + + @Test + public void testComplexExpressionMixedAndOr() { + String act; + var s = List.of(1, 3); + act = new DBSet("table1") + .filter(e -> s.contains(e.b) && (e.b == 1337 || e.b == 420)) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE a.b IN (1,3) AND (a.b = 1337 OR a.b = 420)", act); + + + act = new DBSet("table1") + .filter(e -> (e.b == 1337 || e.b == 420) && s.contains(e.b)) + .toString(); + Assertions.assertEquals("SELECT * FROM (SELECT * FROM `table1`) a WHERE (a.b = 1337 OR a.b = 420) AND a.b IN (1,3)", act); + } + + @Test + public void test() { + + } + + public static class TestClass implements Serializable { + public int b = 1; + } +} diff --git a/src/test/java/jef/visitors/DebugExpressionVisitorTest.java b/src/test/java/jef/visitors/DebugExpressionVisitorTest.java new file mode 100644 index 0000000..a027b07 --- /dev/null +++ b/src/test/java/jef/visitors/DebugExpressionVisitorTest.java @@ -0,0 +1,23 @@ +package jef.visitors; + +import jef.DBSet; +import jef.Queryable; +import jef.expressions.visitors.DebugExpressionVisitor; +import jef.operations.FilterOpTest; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class DebugExpressionVisitorTest { + @Test + public void test() { + var s = List.of(1, 3); + Queryable q = new DBSet("table1") + .filter(e -> s.contains(e.b) && e.b == 1337 && e.b == 420); + new DebugExpressionVisitor().visit(q.getExpression()); + + Queryable q2 = new DBSet("table1") + .filter(e -> s.contains(e.b) || e.b == 1337); + new DebugExpressionVisitor().visit(q2.getExpression()); + } +} \ No newline at end of file