diff options
author | kpandit <kpandit@pinterest.com> | 2021-11-20 00:56:17 +0100 |
---|---|---|
committer | Jens Geyer <jensg@apache.org> | 2021-11-20 00:57:57 +0100 |
commit | 5a9d139be4ef1a790da4c6f25377b8ab6573a325 (patch) | |
tree | 0377e631e8faada4e4c7cb1db3e47af89518a650 /lib | |
parent | 2c0927826d1e7f7e902f29a925e22058f949f535 (diff) | |
download | thrift-5a9d139be4ef1a790da4c6f25377b8ab6573a325.tar.gz |
THRIFT-5443: add support for partial Thrift deserialization
Client: java
Patch: Bhalchandra Pandit
This closes #2439
Diffstat (limited to 'lib')
28 files changed, 5037 insertions, 10 deletions
diff --git a/lib/java/gradle.properties b/lib/java/gradle.properties index 6faaa2a09..fdd7e1543 100644 --- a/lib/java/gradle.properties +++ b/lib/java/gradle.properties @@ -33,3 +33,4 @@ tomcat.embed.version=9.0.43 junit.version=4.12 mockito.version=1.10.19 javax.annotation.version=1.3.2 +commons-lang3.version=3.12
\ No newline at end of file diff --git a/lib/java/gradle/environment.gradle b/lib/java/gradle/environment.gradle index b6cfb2123..12fee154f 100644 --- a/lib/java/gradle/environment.gradle +++ b/lib/java/gradle/environment.gradle @@ -69,6 +69,7 @@ dependencies { compile "org.apache.httpcomponents:httpcore:${httpcoreVersion}" compile "javax.servlet:javax.servlet-api:${servletVersion}" compile "javax.annotation:javax.annotation-api:${javaxAnnotationVersion}" + compile "org.apache.commons:commons-lang3:3.12.0" testCompile "junit:junit:${junitVersion}" testCompile "org.mockito:mockito-all:${mockitoVersion}" diff --git a/lib/java/gradle/generateTestThrift.gradle b/lib/java/gradle/generateTestThrift.gradle index 121bf537d..4b712ca23 100644 --- a/lib/java/gradle/generateTestThrift.gradle +++ b/lib/java/gradle/generateTestThrift.gradle @@ -81,6 +81,7 @@ task generateJava(group: 'Build') { thriftCompile(it, 'JavaDeepCopyTest.thrift') thriftCompile(it, 'EnumContainersTest.thrift') thriftCompile(it, 'JavaBinaryDefault.thrift') + thriftCompile(it, 'partial/thrift_test_schema.thrift') } task generateBeanJava(group: 'Build') { diff --git a/lib/java/src/org/apache/thrift/TDeserializer.java b/lib/java/src/org/apache/thrift/TDeserializer.java index fc8cb8332..1433f6240 100644 --- a/lib/java/src/org/apache/thrift/TDeserializer.java +++ b/lib/java/src/org/apache/thrift/TDeserializer.java @@ -19,18 +19,29 @@ package org.apache.thrift; -import java.io.UnsupportedEncodingException; -import java.nio.ByteBuffer; - +import org.apache.thrift.meta_data.EnumMetaData; +import org.apache.thrift.meta_data.StructMetaData; +import org.apache.thrift.partial.TFieldData; +import org.apache.thrift.partial.ThriftFieldValueProcessor; +import org.apache.thrift.partial.ThriftMetadata; +import org.apache.thrift.partial.ThriftStructProcessor; +import org.apache.thrift.partial.Validate; import org.apache.thrift.protocol.TBinaryProtocol; import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TList; +import org.apache.thrift.protocol.TMap; import org.apache.thrift.protocol.TProtocol; import org.apache.thrift.protocol.TProtocolFactory; import org.apache.thrift.protocol.TProtocolUtil; +import org.apache.thrift.protocol.TSet; import org.apache.thrift.protocol.TType; import org.apache.thrift.transport.TMemoryInputTransport; import org.apache.thrift.transport.TTransportException; +import java.io.UnsupportedEncodingException; +import java.nio.ByteBuffer; +import java.util.Collection; + /** * Generic utility for easily deserializing objects from a byte array or Java * String. @@ -40,6 +51,12 @@ public class TDeserializer { private final TProtocol protocol_; private final TMemoryInputTransport trans_; + // Metadata that describes fields to deserialize during partial deserialization. + private ThriftMetadata.ThriftStruct metadata_ = null; + + // Processor that handles deserialized field values during partial deserialization. + private ThriftFieldValueProcessor processor_ = null; + /** * Create a new TDeserializer that uses the TBinaryProtocol by default. * @@ -62,6 +79,54 @@ public class TDeserializer { } /** + * Construct a new TDeserializer that supports partial deserialization + * that outputs instances of type controlled by the given {@code processor}. + * + * @param thriftClass a TBase derived class. + * @param fieldNames list of fields to deserialize. + * @param processor the Processor that handles deserialized field values. + * @param protocolFactory the Factory to create a protocol. + */ + public TDeserializer( + Class<? extends TBase> thriftClass, + Collection<String> fieldNames, + ThriftFieldValueProcessor processor, + TProtocolFactory protocolFactory) throws TTransportException { + this(protocolFactory); + + Validate.checkNotNull(thriftClass, "thriftClass"); + Validate.checkNotNull(fieldNames, "fieldNames"); + Validate.checkNotNull(processor, "processor"); + + metadata_ = ThriftMetadata.ThriftStruct.fromFieldNames(thriftClass, fieldNames); + processor_ = processor; + } + + /** + * Construct a new TDeserializer that supports partial deserialization + * that outputs {@code TBase} instances. + * + * @param thriftClass a TBase derived class. + * @param fieldNames list of fields to deserialize. + * @param protocolFactory the Factory to create a protocol. + */ + public TDeserializer( + Class<? extends TBase> thriftClass, + Collection<String> fieldNames, + TProtocolFactory protocolFactory) throws TTransportException { + this(thriftClass, fieldNames, new ThriftStructProcessor(), protocolFactory); + } + + /** + * Gets the metadata used for partial deserialization. + * + * @return the metadata used for partial deserialization. + */ + public ThriftMetadata.ThriftStruct getMetadata() { + return metadata_; + } + + /** * Deserialize the Thrift object from a byte array. * * @param base The object to read into @@ -82,12 +147,16 @@ public class TDeserializer { * @throws TException if an error is encountered during deserialization. */ public void deserialize(TBase base, byte[] bytes, int offset, int length) throws TException { - try { - trans_.reset(bytes, offset, length); - base.read(protocol_); - } finally { - trans_.clear(); - protocol_.reset(); + if (this.isPartialDeserializationMode()) { + this.partialDeserializeThriftObject(base, bytes, offset, length); + } else { + try { + trans_.reset(bytes, offset, length); + base.read(protocol_); + } finally { + trans_.clear(); + protocol_.reset(); + } } } @@ -353,4 +422,305 @@ public class TDeserializer { public void fromString(TBase base, String data) throws TException { deserialize(base, data.getBytes()); } + + // ---------------------------------------------------------------------- + // Methods related to partial deserialization. + + /** + * Partially deserializes the given serialized blob. + * + * @param bytes the serialized blob. + * @return deserialized instance. + * @throws TException if an error is encountered during deserialization. + */ + public Object partialDeserializeObject(byte[] bytes) throws TException { + return this.partialDeserializeObject(bytes, 0, bytes.length); + } + + /** + * Partially deserializes the given serialized blob into the given {@code TBase} instance. + * + * @param base the instance into which the given blob is deserialized. + * @param bytes the serialized blob. + * @param offset the blob is read starting at this offset. + * @param length the size of blob read (in number of bytes). + * @return deserialized instance. + * @throws TException if an error is encountered during deserialization. + */ + public Object partialDeserializeThriftObject(TBase base, byte[] bytes, int offset, int length) + throws TException { + ensurePartialThriftDeserializationMode(); + + return this.partialDeserializeObject(base, bytes, offset, length); + } + + /** + * Partially deserializes the given serialized blob. + * + * @param bytes the serialized blob. + * @param offset the blob is read starting at this offset. + * @param length the size of blob read (in number of bytes). + * @return deserialized instance. + * @throws TException if an error is encountered during deserialization. + */ + public Object partialDeserializeObject(byte[] bytes, int offset, int length) throws TException { + ensurePartialDeserializationMode(); + + return this.partialDeserializeObject(null, bytes, offset, length); + } + + /** + * Partially deserializes the given serialized blob. + * + * @param instance the instance into which the given blob is deserialized. + * @param bytes the serialized blob. + * @param offset the blob is read starting at this offset. + * @param length the size of blob read (in number of bytes). + * @return deserialized instance. + * @throws TException if an error is encountered during deserialization. + */ + private Object partialDeserializeObject(Object instance, byte[] bytes, int offset, int length) + throws TException { + ensurePartialDeserializationMode(); + + this.trans_.reset(bytes, offset, length); + this.protocol_.reset(); + return this.deserializeStruct(instance, this.metadata_); + } + + private Object deserialize(ThriftMetadata.ThriftObject data) throws TException { + + Object value; + byte fieldType = data.data.valueMetaData.type; + switch (fieldType) { + case TType.STRUCT: + return this.deserializeStruct(null, (ThriftMetadata.ThriftStruct) data); + + case TType.LIST: + return this.deserializeList((ThriftMetadata.ThriftList) data); + + case TType.MAP: + return this.deserializeMap((ThriftMetadata.ThriftMap) data); + + case TType.SET: + return this.deserializeSet((ThriftMetadata.ThriftSet) data); + + case TType.ENUM: + return this.deserializeEnum((ThriftMetadata.ThriftEnum) data); + + case TType.BOOL: + return this.protocol_.readBool(); + + case TType.BYTE: + return this.protocol_.readByte(); + + case TType.I16: + return this.protocol_.readI16(); + + case TType.I32: + return this.protocol_.readI32(); + + case TType.I64: + return this.protocol_.readI64(); + + case TType.DOUBLE: + return this.protocol_.readDouble(); + + case TType.STRING: + if (((ThriftMetadata.ThriftPrimitive) data).isBinary()) { + return this.processor_.prepareBinary(this.protocol_.readBinary()); + } else { + return this.processor_.prepareString(this.protocol_.readBinary()); + } + + default: + throw unsupportedFieldTypeException(fieldType); + } + } + + private Object deserializeStruct(Object instance, ThriftMetadata.ThriftStruct data) + throws TException { + + if (instance == null) { + instance = this.processor_.createNewStruct(data); + } + + this.protocol_.readStructBegin(); + while (true) { + int tfieldData = this.protocol_.readFieldBeginData(); + byte tfieldType = TFieldData.getType(tfieldData); + if (tfieldType == TType.STOP) { + break; + } + + Integer id = (int) TFieldData.getId(tfieldData); + ThriftMetadata.ThriftObject field = (ThriftMetadata.ThriftObject) data.fields.get(id); + + if (field != null) { + this.deserializeStructField(instance, field.fieldId, field); + } else { + this.protocol_.skip(tfieldType); + } + this.protocol_.readFieldEnd(); + } + this.protocol_.readStructEnd(); + + return this.processor_.prepareStruct(instance); + } + + private void deserializeStructField( + Object instance, + TFieldIdEnum fieldId, + ThriftMetadata.ThriftObject data) throws TException { + + byte fieldType = data.data.valueMetaData.type; + Object value; + + switch (fieldType) { + case TType.BOOL: + this.processor_.setBool(instance, fieldId, this.protocol_.readBool()); + break; + + case TType.BYTE: + this.processor_.setByte(instance, fieldId, this.protocol_.readByte()); + break; + + case TType.I16: + this.processor_.setInt16(instance, fieldId, this.protocol_.readI16()); + break; + + case TType.I32: + this.processor_.setInt32(instance, fieldId, this.protocol_.readI32()); + break; + + case TType.I64: + this.processor_.setInt64(instance, fieldId, this.protocol_.readI64()); + break; + + case TType.DOUBLE: + this.processor_.setDouble(instance, fieldId, this.protocol_.readDouble()); + break; + + case TType.STRING: + if (((ThriftMetadata.ThriftPrimitive) data).isBinary()) { + this.processor_.setBinary(instance, fieldId, this.protocol_.readBinary()); + } else { + this.processor_.setString(instance, fieldId, this.protocol_.readBinary()); + } + break; + + case TType.STRUCT: + value = this.deserializeStruct(null, (ThriftMetadata.ThriftStruct) data); + this.processor_.setStructField(instance, fieldId, value); + break; + + case TType.LIST: + value = this.deserializeList((ThriftMetadata.ThriftList) data); + this.processor_.setListField(instance, fieldId, value); + break; + + case TType.MAP: + value = this.deserializeMap((ThriftMetadata.ThriftMap) data); + this.processor_.setMapField(instance, fieldId, value); + break; + + case TType.SET: + value = this.deserializeSet((ThriftMetadata.ThriftSet) data); + this.processor_.setSetField(instance, fieldId, value); + break; + + case TType.ENUM: + value = this.deserializeEnum((ThriftMetadata.ThriftEnum) data); + this.processor_.setEnumField(instance, fieldId, value); + break; + + default: + throw new RuntimeException("Unsupported field type: " + fieldId.toString()); + } + } + + private Object deserializeList(ThriftMetadata.ThriftList data) throws TException { + + TList tlist = this.protocol_.readListBegin(); + Object instance = this.processor_.createNewList(tlist.size); + for (int i = 0; i < tlist.size; i++) { + Object value = this.deserialize(data.elementData); + this.processor_.setListElement(instance, i, value); + } + this.protocol_.readListEnd(); + return this.processor_.prepareList(instance); + } + + private Object deserializeMap(ThriftMetadata.ThriftMap data) throws TException { + TMap tmap = this.protocol_.readMapBegin(); + Object instance = this.processor_.createNewMap(tmap.size); + for (int i = 0; i < tmap.size; i++) { + Object key = this.deserialize(data.keyData); + Object val = this.deserialize(data.valueData); + this.processor_.setMapElement(instance, i, key, val); + } + this.protocol_.readMapEnd(); + return this.processor_.prepareMap(instance); + } + + private Object deserializeSet(ThriftMetadata.ThriftSet data) throws TException { + TSet tset = this.protocol_.readSetBegin(); + Object instance = this.processor_.createNewSet(tset.size); + for (int i = 0; i < tset.size; i++) { + Object eltValue = this.deserialize(data.elementData); + this.processor_.setSetElement(instance, i, eltValue); + } + this.protocol_.readSetEnd(); + return this.processor_.prepareSet(instance); + } + + private Object deserializeEnum(ThriftMetadata.ThriftEnum data) throws TException { + int ordinal = this.protocol_.readI32(); + Class<? extends TEnum> enumClass = ((EnumMetaData) data.data.valueMetaData).enumClass; + return this.processor_.prepareEnum(enumClass, ordinal); + } + + private <T extends TBase> T createNewStruct(ThriftMetadata.ThriftStruct data) { + T instance = null; + + try { + instance = (T) this.getStructClass(data).newInstance(); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + + return instance; + } + + private <T extends TBase> Class<T> getStructClass(ThriftMetadata.ThriftStruct data) { + return (Class<T>) ((StructMetaData) data.data.valueMetaData).structClass; + } + + private static UnsupportedOperationException unsupportedFieldTypeException(byte fieldType) { + return new UnsupportedOperationException("field type not supported: " + fieldType); + } + + private boolean isPartialDeserializationMode() { + return (this.metadata_ != null) && (this.processor_ != null); + } + + private void ensurePartialDeserializationMode() throws IllegalStateException { + if (!this.isPartialDeserializationMode()) { + throw new IllegalStateException( + "Members metadata and processor must be correctly initialized in order to use this method" + ); + } + } + + private void ensurePartialThriftDeserializationMode() throws IllegalStateException { + this.ensurePartialDeserializationMode(); + + if (!(this.processor_ instanceof ThriftStructProcessor)) { + throw new IllegalStateException( + "processor must be an instance of ThriftStructProcessor to use this method" + ); + } + } } diff --git a/lib/java/src/org/apache/thrift/partial/EnumCache.java b/lib/java/src/org/apache/thrift/partial/EnumCache.java new file mode 100644 index 000000000..22423f10c --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/EnumCache.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.partial.Validate; + +import org.apache.thrift.TEnum; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +/** + * Provides a memoized way to lookup an enum by its value. + * + * This class is used internally by {@code TDeserializer}. + * It is not intended to be used separately on its own. + */ +public class EnumCache { + private static Logger LOG = LoggerFactory.getLogger(EnumCache.class); + + private Map<Class<? extends TEnum>, Map<Integer, TEnum>> classMap; + + public EnumCache() { + this.classMap = new HashMap<>(); + } + + /** + * Gets an instance of the enum type {@code enumClass} + * corresponding to the given {@code value}. + * + * @param enumClass class of the enum to be returned. + * @param value value returned by {@code getValue()}. + */ + public TEnum get(Class<? extends TEnum> enumClass, int value) { + Validate.checkNotNull(enumClass, "enumClass"); + + Map<Integer, TEnum> valueMap = classMap.get(enumClass); + if (valueMap == null) { + valueMap = addClass(enumClass); + if (valueMap == null) { + return null; + } + } + + return valueMap.get(value); + } + + private Map<Integer, TEnum> addClass(Class<? extends TEnum> enumClass) { + try { + Method valuesMethod = enumClass.getMethod("values"); + TEnum[] enumValues = (TEnum[]) valuesMethod.invoke(null); + Map<Integer, TEnum> valueMap = new HashMap<>(); + + for (TEnum enumValue : enumValues) { + valueMap.put(enumValue.getValue(), enumValue); + } + + classMap.put(enumClass, valueMap); + return valueMap; + } catch (NoSuchMethodException e) { + LOG.error("enum class does not have values() method", e); + return null; + } catch (IllegalAccessException e) { + LOG.error("Enum.values() method should be public!", e); + return null; + } catch (InvocationTargetException e) { + LOG.error("Enum.values() threw exception", e); + return null; + } + } +} diff --git a/lib/java/src/org/apache/thrift/partial/PartialThriftComparer.java b/lib/java/src/org/apache/thrift/partial/PartialThriftComparer.java new file mode 100644 index 000000000..f0f33eb4f --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/PartialThriftComparer.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.TBase; +import org.apache.thrift.protocol.TType; + +import java.lang.StringBuilder; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Enables comparison of two TBase instances such that the comparison + * is limited to the subset of fields defined by the supplied metadata. + * + * This comparer is useful when comparing two instances where: + * -- one is generated by full deserialization. + * -- the other is generated by partial deserialization. + * + * The typical use case is to establish correctness of partial deserialization. + */ +public class PartialThriftComparer<T extends TBase> { + + private enum ComparisonResult { + UNKNOWN, + EQUAL, + NOT_EQUAL + } + + // Metadata that defines the scope of comparison. + private ThriftMetadata.ThriftStruct metadata; + + /** + * Constructs an instance of {@link PartialThriftComparer}. + * + * @param metadata defines the scope of comparison. + */ + public PartialThriftComparer(ThriftMetadata.ThriftStruct metadata) { + this.metadata = metadata; + } + + /** + * Compares thrift objects {@code t1} and {@code t2} and + * returns true if they are equal false otherwise. The comparison is limited + * to the scope defined by {@code metadata}. + * <p> + * If the objects are not equal then it optionally records their differences + * if {@code sb} is supplied. + * <p> + * + * @param t1 the first object. + * @param t2 the second object. + * @param sb if non-null, results of the comparison are returned in it. + * @return true if objects are equivalent, false otherwise. + */ + public boolean areEqual(T t1, T t2, StringBuilder sb) { + return this.areEqual(this.metadata, t1, t2, sb); + } + + private boolean areEqual( + ThriftMetadata.ThriftObject data, + Object o1, + Object o2, + StringBuilder sb) { + + byte fieldType = data.data.valueMetaData.type; + switch (fieldType) { + case TType.STRUCT: + return this.areEqual((ThriftMetadata.ThriftStruct) data, o1, o2, sb); + + case TType.LIST: + return this.areEqual((ThriftMetadata.ThriftList) data, o1, o2, sb); + + case TType.MAP: + return this.areEqual((ThriftMetadata.ThriftMap) data, o1, o2, sb); + + case TType.SET: + return this.areEqual((ThriftMetadata.ThriftSet) data, o1, o2, sb); + + case TType.ENUM: + return this.areEqual((ThriftMetadata.ThriftEnum) data, o1, o2, sb); + + case TType.BOOL: + case TType.BYTE: + case TType.I16: + case TType.I32: + case TType.I64: + case TType.DOUBLE: + case TType.STRING: + return this.areEqual((ThriftMetadata.ThriftPrimitive) data, o1, o2, sb); + + default: + throw unsupportedFieldTypeException(fieldType); + } + } + + private boolean areEqual( + ThriftMetadata.ThriftStruct data, + Object o1, + Object o2, + StringBuilder sb) { + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + TBase t1 = (TBase) o1; + TBase t2 = (TBase) o2; + + if (data.fields.size() == 0) { + if (t1.equals(t2)) { + return true; + } else { + appendNotEqual(data, sb, t1, t2, "struct1", "struct2"); + return false; + } + } else { + + boolean overallResult = true; + + for (Object o : data.fields.values()) { + ThriftMetadata.ThriftObject field = (ThriftMetadata.ThriftObject) o; + Object f1 = t1.getFieldValue(field.fieldId); + Object f2 = t2.getFieldValue(field.fieldId); + overallResult = overallResult && this.areEqual(field, f1, f2, sb); + } + + return overallResult; + } + } + + private boolean areEqual( + ThriftMetadata.ThriftPrimitive data, + Object o1, + Object o2, + StringBuilder sb) { + + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + if (data.isBinary()) { + if (areBinaryFieldsEqual(o1, o2)) { + return true; + } + } else if (o1.equals(o2)) { + return true; + } + + appendNotEqual(data, sb, o1, o2, "o1", "o2"); + return false; + } + + private boolean areEqual( + ThriftMetadata.ThriftEnum data, + Object o1, + Object o2, + StringBuilder sb) { + + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + if (o1.equals(o2)) { + return true; + } + + appendNotEqual(data, sb, o1, o2, "o1", "o2"); + return false; + } + + private boolean areEqual( + ThriftMetadata.ThriftList data, + Object o1, + Object o2, + StringBuilder sb) { + + List<Object> l1 = (List<Object>) o1; + List<Object> l2 = (List<Object>) o2; + + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + if (!checkSizeEquality(data, l1, l2, sb, "list")) { + return false; + } + + for (int i = 0; i < l1.size(); i++) { + Object e1 = l1.get(i); + Object e2 = l2.get(i); + if (!this.areEqual(data.elementData, e1, e2, sb)) { + return false; + } + } + + return true; + } + + private boolean areEqual( + ThriftMetadata.ThriftSet data, + Object o1, + Object o2, + StringBuilder sb) { + + Set<Object> s1 = (Set<Object>) o1; + Set<Object> s2 = (Set<Object>) o2; + + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + if (!checkSizeEquality(data, s1, s2, sb, "set")) { + return false; + } + + for (Object e1 : s1) { + if (!s2.contains(e1)) { + appendResult(data, sb, "Element %s in s1 not found in s2", e1); + return false; + } + } + + return true; + } + + private boolean areEqual( + ThriftMetadata.ThriftMap data, + Object o1, + Object o2, + StringBuilder sb) { + + Map<Object, Object> m1 = (Map<Object, Object>) o1; + Map<Object, Object> m2 = (Map<Object, Object>) o2; + + ComparisonResult result = checkNullEquality(data, o1, o2, sb); + if (result != ComparisonResult.UNKNOWN) { + return result == ComparisonResult.EQUAL; + } + + if (!checkSizeEquality(data, m1.keySet(), m2.keySet(), sb, "map.keySet")) { + return false; + } + + for (Object k1 : m1.keySet()) { + if (!m2.containsKey(k1)) { + appendResult(data, sb, "Key %s in m1 not found in m2", k1); + return false; + } + + Object v1 = m1.get(k1); + Object v2 = m2.get(k1); + if (!this.areEqual(data.valueData, v1, v2, sb)) { + return false; + } + } + + return true; + } + + private boolean areBinaryFieldsEqual(Object o1, Object o2) { + if (o1 instanceof byte[]) { + if (Arrays.equals((byte[]) o1, (byte[]) o2)) { + return true; + } + } else if (o1 instanceof ByteBuffer) { + if (((ByteBuffer) o1).compareTo((ByteBuffer) o2) == 0) { + return true; + } + } else { + throw new UnsupportedOperationException( + String.format("Unsupported binary field type: %s", o1.getClass().getName())); + } + + return false; + } + + private void appendResult( + ThriftMetadata.ThriftObject data, + StringBuilder sb, + String format, + Object... args) { + if (sb != null) { + String msg = String.format(format, args); + sb.append(data.fieldId.getFieldName()); + sb.append(" : "); + sb.append(msg); + } + } + + private void appendNotEqual( + ThriftMetadata.ThriftObject data, + StringBuilder sb, + Object o1, + Object o2, + String o1name, + String o2name) { + + String o1s = o1.toString(); + String o2s = o2.toString(); + + if ((o1s.length() + o2s.length()) < 100) { + appendResult(data, sb, "%s (%s) != %s (%s)", o1name, o1s, o2name, o2s); + } else { + appendResult( + data, sb, "%s != %s\n%s =\n%s\n%s =\n%s\n", + o1name, o2name, o1name, o1s, o2name, o2s); + } + } + + private ComparisonResult checkNullEquality( + ThriftMetadata.ThriftObject data, + Object o1, + Object o2, + StringBuilder sb) { + if ((o1 == null) && (o2 == null)) { + return ComparisonResult.EQUAL; + } + + if (o1 == null) { + appendResult(data, sb, "o1 (null) != o2"); + } + + if (o2 == null) { + appendResult(data, sb, "o1 != o2 (null)"); + } + + return ComparisonResult.UNKNOWN; + } + + private boolean checkSizeEquality( + ThriftMetadata.ThriftObject data, + Collection c1, + Collection c2, + StringBuilder sb, + String typeName) { + + if (c1.size() != c2.size()) { + appendResult( + data, sb, "%s1.size(%d) != %s2.size(%d)", + typeName, c1.size(), typeName, c2.size()); + return false; + } + + return true; + } + + static UnsupportedOperationException unsupportedFieldTypeException(byte fieldType) { + return new UnsupportedOperationException("field type not supported: '" + fieldType + "'"); + } +} diff --git a/lib/java/src/org/apache/thrift/partial/README.md b/lib/java/src/org/apache/thrift/partial/README.md new file mode 100644 index 000000000..d5794fae7 --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/README.md @@ -0,0 +1,112 @@ +# Partial Thrift Deserialization + +## Overview +This document describes how partial deserialization of Thrift works. There are two main goals of this documentation: +1. Make it easier to understand the current Java implementation in this folder. +1. Be useful in implementing partial deserialization support in additional languages. + +This document is divided into two high level areas. The first part explains important concepts relevant to partial deserialization. The second part describes components involved in the Java implementation in this folder. + +Moreover, this blog provides some performance numbers and addtional information: https://medium.com/pinterest-engineering/improving-data-processing-efficiency-using-partial-deserialization-of-thrift-16bc3a4a38b4 + +## Basic Concepts + +### Motivation + +The main motivation behind implementing this feature is to improve performance when we need to access only a subset of fields in any Thrift object. This situation arises often when big data is stored in Thrift encoded format (for example, SequenceFile with serialized Thrift values). Many data processing jobs may access this data. However, not every job needs to access every field of each object. In such cases, if we have prior knowledge of the fields needed for a given job, we can deserialize only that subset of fields and avoid the cost deserializing the rest of the fields. There are two benefits of this approach: we save cpu cycles by not deserializing unnecessary field and we end up reducing gc pressure. Both of the savings quickly add up when processing billions of instances in a data processing job. + +### Partial deserialization + +Partial deserialization involves deserializing only a subset of the fields of a serialized Thrift object while efficiently skipping over the rest. One very important benefit of partial deserialization is that the output of the deserialization process is not limited to a `TBase` derived object. It can deserialize a serialized blob into any type by using an appropriate `ThriftFieldValueProcessor`. + +### Defining the subset of fields to deserialize + +The subset of fields to deserialize is defined using a list of fully qualified field names. For example, consider the Thrift `struct` definition below: + +```Thrift +struct SmallStruct { + 1: optional string stringValue; + 2: optional i16 i16Value; +} + +struct TestStruct { + 1: optional i16 i16Field; + 2: optional list<SmallStruct> structList; + 3: optional set<SmallStruct> structSet; + 4: optional map<string, SmallStruct> structMap; + 5: optional SmallStruct structField; +} +``` + +For the Thrift `struct`, each of the following line shows a fully qualified field definition. Partial deserialization uses a non-empty set of such field definitions to identify the subset of fields to deserialize. + +``` +- i16Field +- structList.stringValue +- structSet.i16Value +- structMap.stringValue +- structField.i16Value +``` + +Note that the syntax of denoting paths involving map fields do not support a way to define sub-fields of the key type. + +For example, the field path `structMap.stringValue` shown above has leaf segment `stringValue` which is a field in map values. + +## Components + +The process of partial deserialization involves the following major components. We have listed names of the Java file(s) implementing each component for easier mapping to the source code. + +### Thrift Metadata + +Source files: +- ThriftField.java +- ThriftMetadata.java + +We saw in the previous section how we can identify the subset of fields to deserialize. As the first step, we need to compile the collection of field definitions into an efficient data structure that we can traverse at runtime. This step is achieved using `ThriftField` and `ThriftMetadata` classes. For example, + +```Java +// First, create a collection of fully qualified field names. +List<String> fieldNames = Arrays.asList("i16Field", "structField.i16Value"); + +// Convert the flat collection into an n-ary tree of fields. +List<ThriftField> fields = ThriftField.fromNames(fieldNames); + +// Compile the tree of fields into internally used metadata. +ThriftMetadata.ThriftStruct metadata = + ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, fields); +``` + +At this point, we have an efficient internal representation of the fields that need to get deserialized. + +### Partial Thrift Protocol + +Source files: +- PartialThriftProtocol.java +- PartialThriftBinaryProtocol.java +- PartialThriftCompactProtocol.java + +This component implements efficient skipping over fields that need not be deserialized. Note that this skipping is more efficient compared to that achieved by using `TProtocolUtil.skip()`. The latter calls the corresponding `read()`, allocates and initializes certain values (for example, strings) and then discards the returned value. In comparison, `PartialThriftProtocol` skips a field by incrementing internal offset into the transport buffer. + +### Partial Thrift Deserializer + +Source files: +- PartialThriftDeserializer.java + +This component, traverses a serialized blob sequentially one field at a time. At the beginning of each field, it consults the informations stored in `ThriftMetadata` to see if that field needs to be deserialized. If yes, then the field is deserialized into a value as would normally take place during regular deserialization process. If that field is not in the target subset then the deserializer calls `PartialThriftProtocol` to efficiently skip over that field. + +### Field Value Processor + +Source files: +- ThriftFieldValueProcessor.java +- ThriftStructProcessor.java + +One very important benefit of partial deserialization is that the output of the deserialization process is not limited to a `TBase` derived object. It can deserialize a serialized blob into any type by using an appropriate `ThriftFieldValueProcessor`. + +When the partial Thrift deserializer deserializes a field, it passes its value to a `ThriftFieldValueProcessor`. The processor gets to decide whether the value is stored as-is or is stored in some intermediate form. The default implementation of this interface is `ThriftStructProcessor`. This implementation outputs a `TBase` derived object. There are other implementations that exist (not included in this drop at present). For example, one implementation enables deserializing a Thrift blob directly into an `InternalRow` used by `Spark`. That has yielded orders of magnitude performance improvement over a `Spark` engine that consumes `Thrift` data using its default deserializer. + +### Miscellanious Helpers + +Files: +- TFieldData.java : Holds the type and id members of a TField into a single int. This encoding scheme obviates the need to instantiate TField during the partial deserialization process. +- EnumCache.java : Provides a memoized way to lookup an enum by its value. +- PartialThriftComparer.java : Enables comparison of two TBase instances such that the comparison is limited to the subset of fields defined by the supplied metadata. diff --git a/lib/java/src/org/apache/thrift/partial/TFieldData.java b/lib/java/src/org/apache/thrift/partial/TFieldData.java new file mode 100644 index 000000000..9ba1a17ce --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/TFieldData.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +/** + * Holds the type and id members of a {@link org.apache.thrift.protocol.TField} into a single int. + * + * This encoding scheme obviates the need to instantiate TField + * during the partial deserialization process. + */ +public class TFieldData { + public static int encode(byte type) { + return (int) (type & 0xff); + } + + public static int encode(byte type, short id) { + return (type & 0xff) | (((int) id) << 8); + } + + public static byte getType(int data) { + return (byte) (0xff & data); + } + + public static short getId(int data) { + return (short) ((0xffff00 & data) >> 8); + } +} diff --git a/lib/java/src/org/apache/thrift/partial/ThriftField.java b/lib/java/src/org/apache/thrift/partial/ThriftField.java new file mode 100644 index 000000000..1b5a08c80 --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/ThriftField.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.partial.Validate; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Holds name of a thrift field and of its sub-fields recursively. + * <p> + * This class is meant to be used in conjunction with {@code TDeserializer}. + */ +public class ThriftField { + + /** + * Name of this field as it appears in a thrift file. Case sensitive. + */ + public final String name; + + /** + * List of sub-fields of this field. + * + * This list should have only those sub-fields that need to be deserialized + * by the {@code TDeserializer}. + */ + public final List<ThriftField> fields; + + /** + * Constructs a {@link ThriftField}. + * + * @param name the name of this field as it appears in a thrift file. Case sensitive. + * @param fields List of sub-fields of this field. + */ + ThriftField(String name, List<ThriftField> fields) { + Validate.checkNotNullAndNotEmpty(name, "name"); + Validate.checkNotNull(fields, "fields"); + + this.name = name; + this.fields = Collections.unmodifiableList(fields); + } + + /** + * Constructs a {@link ThriftField} that does not have any sub-fields. + */ + ThriftField(String name) { + this(name, Collections.emptyList()); + } + + // Internal-only constructor that does not mark fields as read-only. + // That allows fromNames() to construct fields from names. + // The actual value of allowFieldAdds is ignored. + // It is used only for generating a different function signature. + ThriftField(String name, List<ThriftField> fields, boolean allowFieldAdds) { + Validate.checkNotNullAndNotEmpty(name, "name"); + Validate.checkNotNull(fields, "fields"); + + this.name = name; + this.fields = fields; + } + + private int hashcode = 0; + + @Override + public int hashCode() { + if (this.hashcode == 0) { + int hc = this.name.toLowerCase().hashCode(); + for (ThriftField subField : this.fields) { + hc ^= subField.hashCode(); + } + + this.hashcode = hc; + } + + return this.hashcode; + } + + @Override + public boolean equals(Object o) { + if (o == null) { + return false; + } + + if (!(o instanceof ThriftField)) { + return false; + } + + ThriftField other = (ThriftField) o; + + if (!this.name.equalsIgnoreCase(other.name)) { + return false; + } + + if (this.fields.size() != other.fields.size()) { + return false; + } + + for (int i = 0; i < this.fields.size(); i++) { + if (!this.fields.get(i).equals(other.fields.get(i))) { + return false; + } + } + + return true; + } + + @Override + public String toString() { + return String.join(", ", this.getFieldNames()); + } + + public List<String> getFieldNames() { + List<String> fieldsList = new ArrayList<>(); + if (this.fields.size() == 0) { + fieldsList.add(this.name); + } else { + for (ThriftField f : this.fields) { + for (String subF : f.getFieldNames()) { + fieldsList.add(this.name + "." + subF); + } + } + } + + return fieldsList; + } + + /** + * Generates and returns n-ary tree of fields and their sub-fields. + * <p> + * @param fieldNames collection of fully qualified field names. + * + * for example, + * In case of PinJoin thrift struct, the following are valid field names + * -- signature + * -- pins.user.userId + * -- textSignal.termSignal.termDataMap + * + * @return n-ary tree of fields and their sub-fields. + */ + public static List<ThriftField> fromNames(Collection<String> fieldNames) { + Validate.checkNotNullAndNotEmpty(fieldNames, "fieldNames"); + + List<String> fieldNamesList = new ArrayList<>(fieldNames); + Collections.sort(fieldNamesList, String.CASE_INSENSITIVE_ORDER); + + List<ThriftField> fields = new ArrayList<>(); + + for (String fieldName : fieldNamesList) { + List<ThriftField> tfields = fields; + String[] tokens = fieldName.split("\\."); + + for (String token : tokens) { + ThriftField field = findField(token, tfields); + if (field == null) { + field = new ThriftField(token, new ArrayList<>(), true); + tfields.add(field); + } + tfields = field.fields; + } + } + + return makeReadOnly(fields); + } + + private static ThriftField findField(String name, List<ThriftField> fields) { + for (ThriftField field : fields) { + if (field.name.equalsIgnoreCase(name)) { + return field; + } + } + return null; + } + + private static List<ThriftField> makeReadOnly(List<ThriftField> fields) { + List<ThriftField> result = new ArrayList<>(fields.size()); + for (ThriftField field : fields) { + ThriftField copy = new ThriftField(field.name, makeReadOnly(field.fields)); + result.add(copy); + } + return Collections.unmodifiableList(result); + } +} diff --git a/lib/java/src/org/apache/thrift/partial/ThriftFieldValueProcessor.java b/lib/java/src/org/apache/thrift/partial/ThriftFieldValueProcessor.java new file mode 100644 index 000000000..33982d1d2 --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/ThriftFieldValueProcessor.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.TEnum; +import org.apache.thrift.TFieldIdEnum; + +import java.nio.ByteBuffer; + +/** + * Provides an abstraction to process deserialized field values and place them + * into the collection that holds them. This abstraction allows different types + * of collections to be output from partial deserialization. + * + * In case of the usual Thrift deserialization, the collection that holds field + * values is simply an instance of TBase. + */ +public interface ThriftFieldValueProcessor<V> { + + // Struct related methods; + Object createNewStruct(ThriftMetadata.ThriftStruct metadata); + + V prepareStruct(Object instance); + + void setBool(V valueCollection, TFieldIdEnum fieldId, boolean value); + + void setByte(V valueCollection, TFieldIdEnum fieldId, byte value); + + void setInt16(V valueCollection, TFieldIdEnum fieldId, short value); + + void setInt32(V valueCollection, TFieldIdEnum fieldId, int value); + + void setInt64(V valueCollection, TFieldIdEnum fieldId, long value); + + void setDouble(V valueCollection, TFieldIdEnum fieldId, double value); + + void setBinary(V valueCollection, TFieldIdEnum fieldId, ByteBuffer value); + + void setString(V valueCollection, TFieldIdEnum fieldId, ByteBuffer buffer); + + void setEnumField(V valueCollection, TFieldIdEnum fieldId, Object value); + + void setListField(V valueCollection, TFieldIdEnum fieldId, Object value); + + void setMapField(V valueCollection, TFieldIdEnum fieldId, Object value); + + void setSetField(V valueCollection, TFieldIdEnum fieldId, Object value); + + void setStructField(V valueCollection, TFieldIdEnum fieldId, Object value); + + Object prepareEnum(Class<? extends TEnum> enumClass, int ordinal); + + Object prepareString(ByteBuffer buffer); + + Object prepareBinary(ByteBuffer buffer); + + // List field related methods. + Object createNewList(int expectedSize); + + void setListElement(Object instance, int index, Object value); + + Object prepareList(Object instance); + + // Map field related methods. + Object createNewMap(int expectedSize); + + void setMapElement(Object instance, int index, Object key, Object value); + + Object prepareMap(Object instance); + + // Set field related methods. + Object createNewSet(int expectedSize); + + void setSetElement(Object instance, int index, Object value); + + Object prepareSet(Object instance); +} diff --git a/lib/java/src/org/apache/thrift/partial/ThriftMetadata.java b/lib/java/src/org/apache/thrift/partial/ThriftMetadata.java new file mode 100644 index 000000000..984d97249 --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/ThriftMetadata.java @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.commons.lang3.StringUtils; +import org.apache.thrift.TBase; +import org.apache.thrift.TFieldIdEnum; +import org.apache.thrift.TFieldRequirementType; +import org.apache.thrift.TUnion; +import org.apache.thrift.meta_data.FieldMetaData; +import org.apache.thrift.meta_data.FieldValueMetaData; +import org.apache.thrift.meta_data.ListMetaData; +import org.apache.thrift.meta_data.MapMetaData; +import org.apache.thrift.meta_data.SetMetaData; +import org.apache.thrift.meta_data.StructMetaData; +import org.apache.thrift.partial.Validate; +import org.apache.thrift.protocol.TType; + +import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Container for Thrift metadata classes such as {@link ThriftPrimitive}, + * {@link ThriftList}, etc. + * <p> + * This class is mainly used by {@code TDeserializer}. + */ +public class ThriftMetadata { + + enum FieldTypeEnum implements TFieldIdEnum { + ROOT((short) 0, "root"), + ENUM((short) 1, "enum"), + LIST_ELEMENT((short) 2, "listElement"), + MAP_KEY((short) 3, "mapKey"), + MAP_VALUE((short) 4, "mapValue"), + SET_ELEMENT((short) 5, "setElement"); + + private short id; + private String name; + + FieldTypeEnum(short id, String name) { + this.id = id; + this.name = name; + } + + @Override + public short getThriftFieldId() { + return id; + } + + @Override + public String getFieldName() { + return name; + } + } + + private enum ComparisonResult { + UNKNOWN, + EQUAL, + NOT_EQUAL + } + + /** + * Base class of field types that can be partially deserialized. + * + * Holds metadata necessary for partial deserialization. + * The metadata is internally computed and used; therefore it is not visible to + * the users of {@code TDeserializer}. + */ + public abstract static class ThriftObject implements Serializable { + public final ThriftObject parent; + public final TFieldIdEnum fieldId; + public final FieldMetaData data; + + // Placeholder to attach additional data. This class or its descendents + // do not try to access or interpret this field. + public Object additionalData; + + ThriftObject(ThriftObject parent, TFieldIdEnum fieldId, FieldMetaData data) { + this.parent = parent; + this.fieldId = fieldId; + this.data = data; + } + + /** + * Converts this instance to formatted and indented string representation. + * + * @param sb the {@code StringBuilder} to add formatted strings to. + * @param level the current indent level. + */ + protected abstract void toPrettyString(StringBuilder sb, int level); + + /** + * Gets a space string whose length is proportional to the given indent level. + */ + protected String getIndent(int level) { + return StringUtils.repeat(" ", level * 4); + } + + /** + * Helper method to append a formatted string to the given {@code StringBuilder}. + */ + protected void append(StringBuilder sb, String format, Object... args) { + sb.append(String.format(format, args)); + } + + /** + * Gets the name of this field. + */ + protected String getName() { + return this.fieldId.getFieldName(); + } + + protected List<String> noFields = Collections.emptyList(); + + protected String getSubElementName(TFieldIdEnum fieldId) { + return getSubElementName(fieldId, "element"); + } + + protected String getSubElementName(TFieldIdEnum fieldId, String suffix) { + return String.format("%s_%s", fieldId.getFieldName(), suffix); + } + + private static class Factory { + + static ThriftObject createNew( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + List<ThriftField> fields) { + + byte fieldType = data.valueMetaData.type; + switch (fieldType) { + case TType.STRUCT: + return ThriftStructBase.create(parent, fieldId, data, fields); + + case TType.LIST: + return new ThriftList(parent, fieldId, data, fields); + + case TType.MAP: + return new ThriftMap(parent, fieldId, data, fields); + + case TType.SET: + return new ThriftSet(parent, fieldId, data, fields); + + case TType.ENUM: + return new ThriftEnum(parent, fieldId, data); + + case TType.BOOL: + case TType.BYTE: + case TType.I16: + case TType.I32: + case TType.I64: + case TType.DOUBLE: + case TType.STRING: + return new ThriftPrimitive(parent, fieldId, data); + + default: + throw unsupportedFieldTypeException(fieldType); + } + } + } + } + + /** + * Metadata about primitive types. + */ + public static class ThriftPrimitive extends ThriftObject { + ThriftPrimitive(ThriftObject parent, TFieldIdEnum fieldId, FieldMetaData data) { + super(parent, fieldId, data); + } + + public boolean isBinary() { + return this.data.valueMetaData.isBinary(); + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + String fieldType = this.getTypeName(); + this.append(sb, "%s%s %s;\n", this.getIndent(level), fieldType, this.getName()); + } + + private String getTypeName() { + byte fieldType = this.data.valueMetaData.type; + switch (fieldType) { + case TType.BOOL: + return "bool"; + + case TType.BYTE: + return "byte"; + + case TType.I16: + return "i16"; + + case TType.I32: + return "i32"; + + case TType.I64: + return "i64"; + + case TType.DOUBLE: + return "double"; + + case TType.STRING: + if (this.isBinary()) { + return "binary"; + } else { + return "string"; + } + + default: + throw unsupportedFieldTypeException(fieldType); + } + } + + private ThriftStruct getParentStruct() { + ThriftObject tparent = parent; + while (tparent != null) { + if (tparent instanceof ThriftStruct) { + return (ThriftStruct) tparent; + } + tparent = tparent.parent; + } + return null; + } + } + + public static class ThriftEnum extends ThriftObject { + private static EnumCache enums = new EnumCache(); + + ThriftEnum(ThriftObject parent, TFieldIdEnum fieldId, FieldMetaData data) { + super(parent, fieldId, data); + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + this.append(sb, "%senum %s;\n", this.getIndent(level), this.getName()); + } + } + + /** + * Metadata of container like objects: list, set, map + */ + public abstract static class ThriftContainer extends ThriftObject { + + public ThriftContainer( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data) { + super(parent, fieldId, data); + } + + public abstract boolean hasUnion(); + } + + public static class ThriftList extends ThriftContainer { + public final ThriftObject elementData; + + ThriftList( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + List<ThriftField> fields) { + super(parent, fieldId, data); + + this.elementData = ThriftObject.Factory.createNew( + this, + FieldTypeEnum.LIST_ELEMENT, + new FieldMetaData( + getSubElementName(fieldId), + TFieldRequirementType.REQUIRED, + ((ListMetaData) data.valueMetaData).elemMetaData), + fields); + } + + @Override + public boolean hasUnion() { + return this.elementData instanceof ThriftUnion; + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + this.append(sb, "%slist<\n", this.getIndent(level)); + this.elementData.toPrettyString(sb, level + 1); + this.append(sb, "%s> %s;\n", this.getIndent(level), this.getName()); + } + } + + public static class ThriftSet extends ThriftContainer { + public final ThriftObject elementData; + + ThriftSet( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + List<ThriftField> fields) { + super(parent, fieldId, data); + + this.elementData = ThriftObject.Factory.createNew( + this, + FieldTypeEnum.SET_ELEMENT, + new FieldMetaData( + getSubElementName(fieldId), + TFieldRequirementType.REQUIRED, + ((SetMetaData) data.valueMetaData).elemMetaData), + fields); + } + + @Override + public boolean hasUnion() { + return this.elementData instanceof ThriftUnion; + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + this.append(sb, "%sset<\n", this.getIndent(level)); + this.elementData.toPrettyString(sb, level + 1); + this.append(sb, "%s> %s;\n", this.getIndent(level), this.getName()); + } + } + + public static class ThriftMap extends ThriftContainer { + public final ThriftObject keyData; + public final ThriftObject valueData; + + ThriftMap( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + List<ThriftField> fields) { + super(parent, fieldId, data); + + this.keyData = ThriftObject.Factory.createNew( + this, + FieldTypeEnum.MAP_KEY, + new FieldMetaData( + getSubElementName(fieldId, "key"), + TFieldRequirementType.REQUIRED, + ((MapMetaData) data.valueMetaData).keyMetaData), + Collections.emptyList()); + + this.valueData = ThriftObject.Factory.createNew( + this, + FieldTypeEnum.MAP_VALUE, + new FieldMetaData( + getSubElementName(fieldId, "value"), + TFieldRequirementType.REQUIRED, + ((MapMetaData) data.valueMetaData).valueMetaData), + fields); + } + + @Override + public boolean hasUnion() { + return (this.keyData instanceof ThriftUnion) || (this.valueData instanceof ThriftUnion); + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + this.append(sb, "%smap<\n", this.getIndent(level)); + this.append(sb, "%skey = {\n", this.getIndent(level + 1)); + this.keyData.toPrettyString(sb, level + 2); + this.append(sb, "%s},\n", this.getIndent(level + 1)); + this.append(sb, "%svalue = {\n", this.getIndent(level + 1)); + this.valueData.toPrettyString(sb, level + 2); + this.append(sb, "%s}\n", this.getIndent(level + 1)); + this.append(sb, "%s> %s;\n", this.getIndent(level), this.getName()); + } + } + + /** + * Base class for metadata of ThriftStruct and ThriftUnion. + * Holds functionality that is common to both. + */ + public abstract static class ThriftStructBase<U extends TBase> extends ThriftObject { + public ThriftStructBase( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data) { + super(parent, fieldId, data); + } + + public Class<U> getStructClass() { + return getStructClass(this.data); + } + + public static <U extends TBase> Class<U> getStructClass(FieldMetaData data) { + return (Class<U>) ((StructMetaData) data.valueMetaData).structClass; + } + + public boolean isUnion() { + return isUnion(this.data); + } + + public static boolean isUnion(FieldMetaData data) { + return TUnion.class.isAssignableFrom(getStructClass(data)); + } + + public static ThriftStructBase create( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + Iterable<ThriftField> fieldsData) { + + if (isUnion(data)) { + return new ThriftUnion(parent, fieldId, data, fieldsData); + } else { + return new ThriftStruct(parent, fieldId, data, fieldsData); + } + } + } + + /** + * Metadata of a Thrift union. + * Currently not adequately supported. + */ + public static class ThriftUnion<U extends TBase> extends ThriftStructBase { + public ThriftUnion( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + Iterable<ThriftField> fieldsData) { + super(parent, fieldId, data); + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + String indent = this.getIndent(level); + String indent2 = this.getIndent(level + 1); + this.append(sb, "%sunion %s {\n", indent, this.getName()); + this.append(sb, "%s// unions not adequately supported at present.\n", indent2); + this.append(sb, "%s}\n", indent); + } + } + + /** + * Metadata of a Thrift struct. + */ + public static class ThriftStruct<U extends TBase> extends ThriftStructBase { + public final Map<Integer, ThriftObject> fields; + + ThriftStruct( + ThriftObject parent, + TFieldIdEnum fieldId, + FieldMetaData data, + Iterable<ThriftField> fieldsData) { + super(parent, fieldId, data); + + Class<U> clasz = getStructClass(data); + this.fields = getFields(this, clasz, fieldsData); + } + + public <T extends TBase> T createNewStruct() { + T instance = null; + + try { + Class<T> structClass = getStructClass(this.data); + instance = (T) structClass.newInstance(); + } catch (InstantiationException e) { + throw new RuntimeException(e); + } catch (IllegalAccessException e) { + throw new RuntimeException(e); + } + + return instance; + } + + public static <T extends TBase> ThriftStruct of(Class<T> clasz) { + return ThriftStruct.fromFields(clasz, Collections.emptyList()); + } + + public static <T extends TBase> ThriftStruct fromFieldNames( + Class<T> clasz, + Collection<String> fieldNames) { + return fromFields(clasz, ThriftField.fromNames(fieldNames)); + } + + public static <T extends TBase> ThriftStruct fromFields( + Class<T> clasz, + Iterable<ThriftField> fields) { + + Validate.checkNotNull(clasz, "clasz"); + Validate.checkNotNull(fields, "fields"); + + return new ThriftStruct( + null, + FieldTypeEnum.ROOT, + new FieldMetaData( + FieldTypeEnum.ROOT.getFieldName(), + TFieldRequirementType.REQUIRED, + new StructMetaData(TType.STRUCT, clasz)), + fields); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + this.toPrettyString(sb, 0); + return sb.toString(); + } + + @Override + protected void toPrettyString(StringBuilder sb, int level) { + String indent = this.getIndent(level); + String indent2 = this.getIndent(level + 1); + this.append(sb, "%sstruct %s {\n", indent, this.getName()); + if (this.fields.size() == 0) { + this.append(sb, "%s*;", indent2); + } else { + List<Integer> ids = new ArrayList(this.fields.keySet()); + Collections.sort(ids); + for (Integer id : ids) { + this.fields.get(id).toPrettyString(sb, level + 1); + } + } + this.append(sb, "%s}\n", indent); + } + + private static <U extends TBase> Map<Integer, ThriftObject> getFields( + ThriftStruct parent, + Class<U> clasz, + Iterable<ThriftField> fieldsData) { + + Map<? extends TFieldIdEnum, FieldMetaData> fieldsMetaData = + FieldMetaData.getStructMetaDataMap(clasz); + Map<Integer, ThriftObject> fields = new HashMap(); + boolean getAllFields = !fieldsData.iterator().hasNext(); + + if (getAllFields) { + for (Map.Entry<? extends TFieldIdEnum, FieldMetaData> entry : fieldsMetaData.entrySet()) { + TFieldIdEnum fieldId = entry.getKey(); + FieldMetaData fieldMetaData = entry.getValue(); + ThriftObject field = + ThriftObject.Factory.createNew(parent, fieldId, fieldMetaData, Collections.emptyList()); + fields.put((int) fieldId.getThriftFieldId(), field); + } + } else { + for (ThriftField fieldData : fieldsData) { + String fieldName = fieldData.name; + FieldMetaData fieldMetaData = findFieldMetaData(fieldsMetaData, fieldName); + TFieldIdEnum fieldId = findFieldId(fieldsMetaData, fieldName); + ThriftObject field = + ThriftObject.Factory.createNew(parent, fieldId, fieldMetaData, fieldData.fields); + fields.put((int) fieldId.getThriftFieldId(), field); + } + } + + return fields; + } + + private static FieldMetaData findFieldMetaData( + Map<? extends TFieldIdEnum, FieldMetaData> fieldsMetaData, + String fieldName) { + + for (FieldMetaData fieldData : fieldsMetaData.values()) { + if (fieldData.fieldName.equals(fieldName)) { + return fieldData; + } + } + + throw fieldNotFoundException(fieldName); + } + + private static TFieldIdEnum findFieldId( + Map<? extends TFieldIdEnum, FieldMetaData> fieldsMetaData, + String fieldName) { + + for (TFieldIdEnum fieldId : fieldsMetaData.keySet()) { + if (fieldId.getFieldName().equals(fieldName)) { + return fieldId; + } + } + + throw fieldNotFoundException(fieldName); + } + } + + static IllegalArgumentException fieldNotFoundException(String fieldName) { + return new IllegalArgumentException("field not found: '" + fieldName + "'"); + } + + static UnsupportedOperationException unsupportedFieldTypeException(byte fieldType) { + return new UnsupportedOperationException("field type not supported: '" + fieldType + "'"); + } +} diff --git a/lib/java/src/org/apache/thrift/partial/ThriftStructProcessor.java b/lib/java/src/org/apache/thrift/partial/ThriftStructProcessor.java new file mode 100644 index 000000000..95789144d --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/ThriftStructProcessor.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.TBase; +import org.apache.thrift.TEnum; +import org.apache.thrift.TFieldIdEnum; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; + +/** + * Provides a way to create and initialize an instance of TBase during partial deserialization. + * + * This class is supposed to be used as a helper class for {@code PartialThriftDeserializer}. + */ +public class ThriftStructProcessor implements ThriftFieldValueProcessor<TBase> { + + private static final EnumCache enums = new EnumCache(); + + @Override + public Object createNewStruct(ThriftMetadata.ThriftStruct metadata) { + return metadata.createNewStruct(); + } + + @Override + public TBase prepareStruct(Object instance) { + return (TBase) instance; + } + + @Override + public Object createNewList(int expectedSize) { + return new Object[expectedSize]; + } + + @Override + public void setListElement(Object instance, int index, Object value) { + ((Object[]) instance)[index] = value; + } + + @Override + public Object prepareList(Object instance) { + return Arrays.asList((Object[]) instance); + } + + @Override + public Object createNewMap(int expectedSize) { + return new HashMap<Object, Object>(expectedSize); + } + + @Override + public void setMapElement(Object instance, int index, Object key, Object value) { + ((HashMap<Object, Object>) instance).put(key, value); + } + + @Override + public Object prepareMap(Object instance) { + return instance; + } + + @Override + public Object createNewSet(int expectedSize) { + return new HashSet<Object>(expectedSize); + } + + @Override + public void setSetElement(Object instance, int index, Object value) { + ((HashSet<Object>) instance).add(value); + } + + @Override + public Object prepareSet(Object instance) { + return instance; + } + + @Override + public Object prepareEnum(Class<? extends TEnum> enumClass, int ordinal) { + return enums.get(enumClass, ordinal); + } + + @Override + public Object prepareString(ByteBuffer buffer) { + return byteBufferToString(buffer); + } + + @Override + public Object prepareBinary(ByteBuffer buffer) { + return buffer; + } + + @Override + public void setBool(TBase valueCollection, TFieldIdEnum fieldId, boolean value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setByte(TBase valueCollection, TFieldIdEnum fieldId, byte value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setInt16(TBase valueCollection, TFieldIdEnum fieldId, short value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setInt32(TBase valueCollection, TFieldIdEnum fieldId, int value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setInt64(TBase valueCollection, TFieldIdEnum fieldId, long value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setDouble(TBase valueCollection, TFieldIdEnum fieldId, double value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setBinary(TBase valueCollection, TFieldIdEnum fieldId, ByteBuffer value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setString(TBase valueCollection, TFieldIdEnum fieldId, ByteBuffer buffer) { + String value = byteBufferToString(buffer); + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setEnumField(TBase valueCollection, TFieldIdEnum fieldId, Object value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setListField(TBase valueCollection, TFieldIdEnum fieldId, Object value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setMapField(TBase valueCollection, TFieldIdEnum fieldId, Object value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setSetField(TBase valueCollection, TFieldIdEnum fieldId, Object value) { + valueCollection.setFieldValue(fieldId, value); + } + + @Override + public void setStructField(TBase valueCollection, TFieldIdEnum fieldId, Object value) { + valueCollection.setFieldValue(fieldId, value); + } + + private static String byteBufferToString(ByteBuffer buffer) { + byte[] bytes = buffer.array(); + int pos = buffer.position(); + return new String(bytes, pos, buffer.limit() - pos, StandardCharsets.UTF_8); + } +} diff --git a/lib/java/src/org/apache/thrift/partial/Validate.java b/lib/java/src/org/apache/thrift/partial/Validate.java new file mode 100644 index 000000000..ef0466a5e --- /dev/null +++ b/lib/java/src/org/apache/thrift/partial/Validate.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collection; + +/** + * A superset of Validate class in Apache commons lang3. + * + * It provides consistent message strings for frequently encountered checks. + * That simplifies callers because they have to supply only the name of the argument + * that failed a check instead of having to supply the entire message. + */ +public final class Validate { + private Validate() {} + + /** + * Validates that the given reference argument is not null. + */ + public static void checkNotNull(Object obj, String argName) { + checkArgument(obj != null, "'%s' must not be null.", argName); + } + + /** + * Validates that the given integer argument is not zero or negative. + */ + public static void checkPositiveInteger(long value, String argName) { + checkArgument(value > 0, "'%s' must be a positive integer.", argName); + } + + /** + * Validates that the given integer argument is not negative. + */ + public static void checkNotNegative(long value, String argName) { + checkArgument(value >= 0, "'%s' must not be negative.", argName); + } + + /* + * Validates that the expression (that checks a required field is present) is true. + */ + public static void checkRequired(boolean isPresent, String argName) { + checkArgument(isPresent, "'%s' is required.", argName); + } + + /** + * Validates that the expression (that checks a field is valid) is true. + */ + public static void checkValid(boolean isValid, String argName) { + checkArgument(isValid, "'%s' is invalid.", argName); + } + + /** + * Validates that the expression (that checks a field is valid) is true. + */ + public static void checkValid(boolean isValid, String argName, String validValues) { + checkArgument(isValid, "'%s' is invalid. Valid values are: %s.", argName, validValues); + } + + /** + * Validates that the given string is not null and has non-zero length. + */ + public static void checkNotNullAndNotEmpty(String arg, String argName) { + Validate.checkNotNull(arg, argName); + Validate.checkArgument( + arg.length() > 0, + "'%s' must not be empty.", + argName); + } + + /** + * Validates that the given array is not null and has at least one element. + */ + public static <T> void checkNotNullAndNotEmpty(T[] array, String argName) { + Validate.checkNotNull(array, argName); + checkNotEmpty(array.length, argName); + } + + /** + * Validates that the given array is not null and has at least one element. + */ + public static void checkNotNullAndNotEmpty(byte[] array, String argName) { + Validate.checkNotNull(array, argName); + checkNotEmpty(array.length, argName); + } + + /** + * Validates that the given array is not null and has at least one element. + */ + public static void checkNotNullAndNotEmpty(short[] array, String argName) { + Validate.checkNotNull(array, argName); + checkNotEmpty(array.length, argName); + } + + /** + * Validates that the given array is not null and has at least one element. + */ + public static void checkNotNullAndNotEmpty(int[] array, String argName) { + Validate.checkNotNull(array, argName); + checkNotEmpty(array.length, argName); + } + + /** + * Validates that the given array is not null and has at least one element. + */ + public static void checkNotNullAndNotEmpty(long[] array, String argName) { + Validate.checkNotNull(array, argName); + checkNotEmpty(array.length, argName); + } + + /** + * Validates that the given buffer is not null and has non-zero capacity. + */ + public static <T> void checkNotNullAndNotEmpty(Iterable<T> iter, String argName) { + Validate.checkNotNull(iter, argName); + int minNumElements = iter.iterator().hasNext() ? 1 : 0; + checkNotEmpty(minNumElements, argName); + } + + /** + * Validates that the given set is not null and has an exact number of items. + */ + public static <T> void checkNotNullAndNumberOfElements( + Collection<T> collection, int numElements, String argName) { + Validate.checkNotNull(collection, argName); + checkArgument( + collection.size() == numElements, + "Number of elements in '%s' must be exactly %s, %s given.", + argName, + numElements, + collection.size() + ); + } + + /** + * Validates that the given two values are equal. + */ + public static void checkValuesEqual( + long value1, + String value1Name, + long value2, + String value2Name) { + checkArgument( + value1 == value2, + "'%s' (%s) must equal '%s' (%s).", + value1Name, + value1, + value2Name, + value2); + } + + /** + * Validates that the first value is an integer multiple of the second value. + */ + public static void checkIntegerMultiple( + long value1, + String value1Name, + long value2, + String value2Name) { + checkArgument( + (value1 % value2) == 0, + "'%s' (%s) must be an integer multiple of '%s' (%s).", + value1Name, + value1, + value2Name, + value2); + } + + /** + * Validates that the first value is greater than the second value. + */ + public static void checkGreater( + long value1, + String value1Name, + long value2, + String value2Name) { + checkArgument( + value1 > value2, + "'%s' (%s) must be greater than '%s' (%s).", + value1Name, + value1, + value2Name, + value2); + } + + /** + * Validates that the first value is greater than or equal to the second value. + */ + public static void checkGreaterOrEqual( + long value1, + String value1Name, + long value2, + String value2Name) { + checkArgument( + value1 >= value2, + "'%s' (%s) must be greater than or equal to '%s' (%s).", + value1Name, + value1, + value2Name, + value2); + } + + /** + * Validates that the first value is less than or equal to the second value. + */ + public static void checkLessOrEqual( + long value1, + String value1Name, + long value2, + String value2Name) { + checkArgument( + value1 <= value2, + "'%s' (%s) must be less than or equal to '%s' (%s).", + value1Name, + value1, + value2Name, + value2); + } + + /** + * Validates that the given value is within the given range of values. + */ + public static void checkWithinRange( + long value, + String valueName, + long minValueInclusive, + long maxValueInclusive) { + checkArgument( + (value >= minValueInclusive) && (value <= maxValueInclusive), + "'%s' (%s) must be within the range [%s, %s].", + valueName, + value, + minValueInclusive, + maxValueInclusive); + } + + /** + * Validates that the given value is within the given range of values. + */ + public static void checkWithinRange( + double value, + String valueName, + double minValueInclusive, + double maxValueInclusive) { + checkArgument( + (value >= minValueInclusive) && (value <= maxValueInclusive), + "'%s' (%s) must be within the range [%s, %s].", + valueName, + value, + minValueInclusive, + maxValueInclusive); + } + + public static void checkPathExists(Path path, String argName) { + checkNotNull(path, argName); + checkArgument(Files.exists(path), "Path %s (%s) does not exist.", argName, path); + } + + public static void checkPathExistsAsDir(Path path, String argName) { + checkPathExists(path, argName); + checkArgument( + Files.isDirectory(path), + "Path %s (%s) must point to a directory.", + argName, + path); + } + + public static void checkPathExistsAsFile(Path path, String argName) { + checkPathExists(path, argName); + checkArgument(Files.isRegularFile(path), "Path %s (%s) must point to a file.", argName, path); + } + + public static void checkArgument(boolean expression, String format, Object... args) { + org.apache.commons.lang3.Validate.isTrue(expression, format, args); + } + + public static void checkState(boolean expression, String format, Object... args) { + org.apache.commons.lang3.Validate.validState(expression, format, args); + } + + private static void checkNotEmpty(int arraySize, String argName) { + Validate.checkArgument( + arraySize > 0, + "'%s' must have at least one element.", + argName); + } +} diff --git a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java index fc46f7c6f..e8444fefb 100644 --- a/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java +++ b/lib/java/src/org/apache/thrift/protocol/TBinaryProtocol.java @@ -23,6 +23,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import org.apache.thrift.TException; +import org.apache.thrift.partial.TFieldData; import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; @@ -486,4 +487,53 @@ public class TBinaryProtocol extends TProtocol { default: throw new TTransportException(TTransportException.UNKNOWN, "unrecognized type code"); } } + // ----------------------------------------------------------------- + // Additional methods to improve performance. + + @Override + public int readFieldBeginData() throws TException { + byte type = this.readByte(); + if (type == TType.STOP) { + return TFieldData.encode(type); + } + + short id = this.readI16(); + return TFieldData.encode(type, id); + } + + @Override + protected void skipBool() throws TException { + this.skipBytes(1); + } + + @Override + protected void skipByte() throws TException { + this.skipBytes(1); + } + + @Override + protected void skipI16() throws TException { + this.skipBytes(2); + } + + @Override + protected void skipI32() throws TException { + this.skipBytes(4); + } + + @Override + protected void skipI64() throws TException { + this.skipBytes(8); + } + + @Override + protected void skipDouble() throws TException { + this.skipBytes(8); + } + + @Override + protected void skipBinary() throws TException { + int size = readI32(); + this.skipBytes(size); + } } diff --git a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java index 4f4e21f50..832e197dd 100644 --- a/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java +++ b/lib/java/src/org/apache/thrift/protocol/TCompactProtocol.java @@ -728,7 +728,7 @@ public class TCompactProtocol extends TProtocol { } getTransport().checkReadBytesAvailable(length); - + if (stringLengthLimit_ != NO_LENGTH_LIMIT && length > stringLengthLimit_) { throw new TProtocolException(TProtocolException.SIZE_LIMIT, "Length exceeded max allowed: " + length); @@ -942,4 +942,12 @@ public class TCompactProtocol extends TProtocol { throw new TTransportException(TTransportException.UNKNOWN, "unrecognized type code"); } } + // ----------------------------------------------------------------- + // Additional methods to improve performance. + + @Override + protected void skipBinary() throws TException { + int size = intToZigZag(readI32()); + this.skipBytes(size); + } } diff --git a/lib/java/src/org/apache/thrift/protocol/TProtocol.java b/lib/java/src/org/apache/thrift/protocol/TProtocol.java index 38c030e73..3589b64e3 100644 --- a/lib/java/src/org/apache/thrift/protocol/TProtocol.java +++ b/lib/java/src/org/apache/thrift/protocol/TProtocol.java @@ -22,6 +22,7 @@ package org.apache.thrift.protocol; import java.nio.ByteBuffer; import org.apache.thrift.TException; +import org.apache.thrift.partial.TFieldData; import org.apache.thrift.scheme.IScheme; import org.apache.thrift.scheme.StandardScheme; import org.apache.thrift.transport.TTransport; @@ -180,4 +181,150 @@ public abstract class TProtocol { public Class<? extends IScheme> getScheme() { return StandardScheme.class; } + + // ----------------------------------------------------------------- + // Additional methods to improve performance. + + public int readFieldBeginData() throws TException { + // Derived classes should provide a more efficient version of this + // method if allowed by the encoding used by that protocol. + TField tfield = this.readFieldBegin(); + return TFieldData.encode(tfield.type, tfield.id); + } + + public void skip(byte fieldType) throws TException { + this.skip(fieldType, Integer.MAX_VALUE); + } + + public void skip(byte fieldType, int maxDepth) throws TException { + if (maxDepth <= 0) { + throw new TException("Maximum skip depth exceeded"); + } + + switch (fieldType) { + case TType.BOOL: + this.skipBool(); + break; + + case TType.BYTE: + this.skipByte(); + break; + + case TType.I16: + this.skipI16(); + break; + + case TType.I32: + this.skipI32(); + break; + + case TType.I64: + this.skipI64(); + break; + + case TType.DOUBLE: + this.skipDouble(); + break; + + case TType.STRING: + this.skipBinary(); + break; + + case TType.STRUCT: + this.readStructBegin(); + while (true) { + int tfieldData = this.readFieldBeginData(); + byte tfieldType = TFieldData.getType(tfieldData); + if (tfieldType == TType.STOP) { + break; + } + this.skip(tfieldType, maxDepth - 1); + this.readFieldEnd(); + } + this.readStructEnd(); + break; + + case TType.MAP: + TMap map = this.readMapBegin(); + for (int i = 0; i < map.size; i++) { + this.skip(map.keyType, maxDepth - 1); + this.skip(map.valueType, maxDepth - 1); + } + this.readMapEnd(); + break; + + case TType.SET: + TSet set = this.readSetBegin(); + for (int i = 0; i < set.size; i++) { + this.skip(set.elemType, maxDepth - 1); + } + this.readSetEnd(); + break; + + case TType.LIST: + TList list = this.readListBegin(); + for (int i = 0; i < list.size; i++) { + this.skip(list.elemType, maxDepth - 1); + } + this.readListEnd(); + break; + + default: + throw new TProtocolException( + TProtocolException.INVALID_DATA, "Unrecognized type " + fieldType); + } + } + + /** + * The default implementation of all skip() methods calls the corresponding read() method. + * Protocols that derive from this class are strongly encouraged to provide + * a more efficient alternative. + */ + + protected void skipBool() throws TException { + this.readBool(); + } + + protected void skipByte() throws TException { + this.readByte(); + } + + protected void skipI16() throws TException { + this.readI16(); + } + + protected void skipI32() throws TException { + this.readI32(); + } + + protected void skipI64() throws TException { + this.readI64(); + } + + protected void skipDouble() throws TException { + this.readDouble(); + } + + protected void skipBinary() throws TException { + this.readBinary(); + } + + static final int MAX_SKIPPED_BYTES = 256; + protected byte[] skippedBytes = new byte[MAX_SKIPPED_BYTES]; + + protected void skipBytes(int numBytes) throws TException { + if (numBytes <= MAX_SKIPPED_BYTES) { + if (this.getTransport().getBytesRemainingInBuffer() >= numBytes) { + this.getTransport().consumeBuffer(numBytes); + } else { + this.getTransport().readAll(skippedBytes, 0, numBytes); + } + } else { + int remaining = numBytes; + while (remaining > 0) { + skipBytes(Math.min(remaining, MAX_SKIPPED_BYTES)); + remaining -= MAX_SKIPPED_BYTES; + } + } + } } diff --git a/lib/java/test/org/apache/thrift/TestPartialThriftDeserializer.java b/lib/java/test/org/apache/thrift/TestPartialThriftDeserializer.java new file mode 100644 index 000000000..c0c7b892d --- /dev/null +++ b/lib/java/test/org/apache/thrift/TestPartialThriftDeserializer.java @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.*; + +import org.apache.thrift.partial.TestStruct; +import org.apache.thrift.partial.ThriftField; +import org.apache.thrift.partial.TstEnum; +import org.apache.thrift.partial.ExceptionAsserts; + +import org.apache.thrift.TBase; +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TException; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class TestPartialThriftDeserializer { + + private ThriftSerDe serde = new ThriftSerDe(); + private TBinaryProtocol.Factory binaryProtocolFactory = new TBinaryProtocol.Factory(); + private TCompactProtocol.Factory compactProtocolFactory = new TCompactProtocol.Factory(); + + private PartialThriftTestData testData = new PartialThriftTestData(); + + public TestPartialThriftDeserializer() throws TException { + } + + @Test + public void testArgChecks() throws TException { + // Should not throw. + List<String> fieldNames = Arrays.asList("i32Field"); + new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory); + + // Verify it throws correctly. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'thriftClass' must not be null", + () -> new TDeserializer(null, fieldNames, binaryProtocolFactory)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'fieldNames' must not be null", + () -> new TDeserializer(TestStruct.class, null, binaryProtocolFactory)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'processor' must not be null", + () -> new TDeserializer(TestStruct.class, fieldNames, null, binaryProtocolFactory)); + } + + /** + * This test does not use partial deserialization. It is used to establish correctness + * of full serialization used in the other tests. + */ + @Test + public void testRoundTripFull() throws TException { + TestStruct ts1 = testData.createTestStruct(1, 2); + + byte[] bytesBinary = serde.serializeBinary(ts1); + byte[] bytesCompact = serde.serializeCompact(ts1); + + TestStruct ts2 = serde.deserializeBinary(bytesBinary, TestStruct.class); + assertEquals(ts1, ts2); + + ts2 = serde.deserializeCompact(bytesCompact, TestStruct.class); + assertEquals(ts1, ts2); + } + + @Test + public void testPartialSimpleField() throws TException, IOException { + TestStruct ts1 = testData.createTestStruct(1, 1); + assertTrue(ts1.isSetI16Field()); + assertTrue(ts1.isSetI32Field()); + + byte[] bytesBinary = serde.serializeBinary(ts1); + byte[] bytesCompact = serde.serializeCompact(ts1); + + List<String> fieldNames = Arrays.asList("i32Field"); + + TDeserializer partialBinaryDeserializer = + new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory); + TDeserializer partialCompactDeserializer = + new TDeserializer(TestStruct.class, fieldNames, compactProtocolFactory); + + PartialThriftComparer comparer = + new PartialThriftComparer(partialBinaryDeserializer.getMetadata()); + + StringBuilder sb = new StringBuilder(); + TestStruct ts2 = (TestStruct) partialBinaryDeserializer.partialDeserializeObject(bytesBinary); + validatePartialSimpleField(ts1, ts2); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + + ts2 = (TestStruct) partialCompactDeserializer.partialDeserializeObject(bytesCompact); + validatePartialSimpleField(ts1, ts2); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + } + + private void validatePartialSimpleField(TestStruct ts1, TestStruct ts2) { + assertTrue(ts2.toString(), ts2.isSetI32Field()); + assertEquals(ts1.getI32Field(), ts2.getI32Field()); + assertFalse(ts2.isSetI16Field()); + } + + @Test + public void testPartialComplex() throws TException { + int id = 1; + int numItems = 10; + TestStruct ts1 = testData.createTestStruct(id, numItems); + + byte[] bytesBinary = serde.serializeBinary(ts1); + byte[] bytesCompact = serde.serializeCompact(ts1); + + List<String> fieldNames = Arrays.asList( + "byteField", + "i16Field", + "i32Field", + "i64Field", + "doubleField", + "stringField", + + "enumField", + "binaryField", + + // List fields + "byteList", + "i16List", + "i32List", + "i64List", + "doubleList", + "stringList", + "enumList", + "listList", + "setList", + "mapList", + "structList", + "binaryList", + + // Set fields + "byteSet", + "i16Set", + "i32Set", + "i64Set", + "doubleSet", + "stringSet", + "enumSet", + "listSet", + "setSet", + "mapSet", + "structSet", + "binarySet", + + // Map fields + "byteMap", + "i16Map", + "i32Map", + "i64Map", + "doubleMap", + "stringMap", + "enumMap", + "listMap", + "setMap", + "mapMap", + "structMap", + "binaryMap", + + // Struct field + "structField" + ); + StringBuilder sb = new StringBuilder(); + TDeserializer partialBinaryDeserializer = + new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory); + TDeserializer partialCompactDeserializer = + new TDeserializer(TestStruct.class, fieldNames, compactProtocolFactory); + PartialThriftComparer comparer = + new PartialThriftComparer(partialBinaryDeserializer.getMetadata()); + + TestStruct ts2 = (TestStruct) partialBinaryDeserializer.partialDeserializeObject(bytesBinary); + validatePartialComplex(ts1, ts2, id, numItems); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + + ts2 = (TestStruct) partialCompactDeserializer.partialDeserializeObject(bytesCompact); + validatePartialComplex(ts1, ts2, id, numItems); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + } + + private void validatePartialComplex(TestStruct ts1, TestStruct ts2, int id, int numItems) { + + // Validate primitive fields. + assertTrue(ts2.toString(), ts2.isSetByteField()); + assertEquals(ts1.getByteField(), ts2.getByteField()); + + assertTrue(ts2.isSetI16Field()); + assertEquals(ts1.getI16Field(), ts2.getI16Field()); + + assertTrue(ts2.isSetI32Field()); + assertEquals(ts1.getI32Field(), ts2.getI32Field()); + + assertTrue(ts2.isSetI64Field()); + assertEquals(ts1.getI64Field(), ts2.getI64Field()); + + assertTrue(ts2.isSetDoubleField()); + assertEquals(ts1.getDoubleField(), ts2.getDoubleField(), 0.0001); + + assertTrue(ts2.isSetStringField()); + assertEquals(ts1.getStringField(), ts2.getStringField()); + + assertTrue(ts2.isSetEnumField()); + assertEquals(ts1.getEnumField(), ts2.getEnumField()); + + assertTrue(ts2.isSetBinaryField()); + assertArrayEquals(ts1.getBinaryField(), ts2.getBinaryField()); + + // Validate list fields. + validateList(ts2.getByteList(), id, numItems); + validateList(ts2.getI16List(), id, numItems); + validateList(ts2.getI32List(), id, numItems); + validateList(ts2.getI64List(), id, numItems); + validateList(ts2.getDoubleList(), id, numItems); + validateStringList(ts2.getStringList(), id, numItems); + validateEnumList(ts2.getEnumList(), id, numItems); + + validateListOfList(ts2.getListList(), id, numItems); + validateListOfSet(ts2.getSetList(), id, numItems); + validateListOfMap(ts2.getMapList(), id, numItems); + validateListOfStruct(ts2.getStructList(), id, numItems); + validateListOfBinary(ts2.getBinaryList(), id, numItems); + + // Validate set fields. + validateSet(ts2.getByteSet(), Byte.class, numItems); + validateSet(ts2.getI16Set(), Short.class, numItems); + validateSet(ts2.getI32Set(), Integer.class, numItems); + validateSet(ts2.getI64Set(), Long.class, numItems); + validateSet(ts2.getDoubleSet(), Double.class, numItems); + validateStringSet(ts2.getStringSet(), id, numItems); + validateEnumSet(ts2.getEnumSet(), id, numItems); + + validateSetOfList(ts2.getListSet(), id, numItems); + validateSetOfSet(ts2.getSetSet(), id, numItems); + validateSetOfMap(ts2.getMapSet(), id, numItems); + validateSetOfStruct(ts2.getStructSet(), id, numItems); + validateSetOfBinary(ts2.getBinarySet(), id, numItems); + + // Validate map fields. + validateMap(ts2.getByteMap(), Byte.class, numItems); + validateMap(ts2.getI16Map(), Short.class, numItems); + validateMap(ts2.getI32Map(), Integer.class, numItems); + validateMap(ts2.getI64Map(), Long.class, numItems); + validateMap(ts2.getDoubleMap(), Double.class, numItems); + validateStringMap(ts2.getStringMap(), id, numItems); + validateEnumMap(ts2.getEnumMap(), id, numItems); + + validateMapOfList(ts2.getListMap(), id, numItems); + validateMapOfSet(ts2.getSetMap(), id, numItems); + validateMapOfMap(ts2.getMapMap(), id, numItems); + validateMapOfStruct(ts2.getStructMap(), id, numItems); + validateMapOfBinary(ts2.getBinaryMap(), id, numItems); + + // Validate struct field. + assertEquals(testData.createSmallStruct(id), ts2.getStructField()); + } + + private void validateNotNullAndNotEmpty(Collection<?> collection, int numItems) { + assertNotNull(collection); + assertEquals(numItems, collection.size()); + } + + // ---------------------------------------------------------------------- + // List validation helpers. + + private <V extends Number> void validateList(List<V> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + assertEquals(i, list.get(i).longValue()); + } + } + + private void validateStringList(List<String> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + for (int i = 0; i < numItems; i++) { + assertEquals(Integer.valueOf(i), Integer.valueOf(list.get(i))); + } + } + + private void validateEnumList(List<TstEnum> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + for (int i = 0; i < numItems; i++) { + assertEquals(TstEnum.E_ONE, list.get(i)); + } + } + + private <V extends Number> void validateListOfList(List<List<V>> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + validateList(list.get(i), id, numItems); + } + } + + private <V extends Number> void validateListOfSet(List<Set<V>> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + Set<V> set = list.get(i); + for (int j = 0; j < numItems; j++) { + assertTrue(set.contains(j)); + } + } + } + + private <V extends Number> void validateListOfMap( + List<Map<String, V>> list, int id, int numItems) { + + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + Map<String, V> map = list.get(i); + for (int j = 0; j < numItems; j++) { + String key = Integer.toString(j); + assertTrue(map.containsKey(key)); + assertEquals(j, map.get(key)); + } + } + } + + private void validateListOfStruct(List<SmallStruct> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + SmallStruct ss = testData.createSmallStruct(i); + for (int j = 0; j < numItems; j++) { + assertEquals(ss, list.get(i)); + } + } + } + + private void validateListOfBinary(List<ByteBuffer> list, int id, int numItems) { + validateNotNullAndNotEmpty(list, numItems); + + for (int i = 0; i < numItems; i++) { + ByteBuffer bb = ByteBuffer.wrap(testData.BYTES); + assertTrue(bb.compareTo(list.get(i)) == 0); + } + } + + // ---------------------------------------------------------------------- + // Set validation helpers. + + private <V extends Number> void validateSet(Set<V> set, Class<V> clasz, int numItems) { + validateNotNullAndNotEmpty(set, numItems); + + for (int i = 0; i < numItems; i++) { + if (clasz == Byte.class) { + assertTrue(set.contains((byte)i)); + } else if (clasz == Short.class) { + assertTrue(set.contains((short)i)); + } else if (clasz == Integer.class) { + assertTrue(set.contains(i)); + } else if (clasz == Long.class) { + assertTrue(set.contains((long)i)); + } else if (clasz == Double.class) { + assertTrue(set.contains((double)i)); + } + } + } + + private void validateStringSet(Set<String> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, numItems); + + for (int i = 0; i < numItems; i++) { + assertTrue(set.contains(Integer.toString(i))); + } + } + + private void validateEnumSet(Set<TstEnum> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, 1); + + assertTrue(set.contains(TstEnum.E_ONE)); + } + + private void validateSetOfList(Set<List<Integer>> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, 1); + + List<Integer> list = new ArrayList<>(numItems); + for (int i = 0; i < numItems; i++) { + list.add(i); + } + + assertTrue(set.contains(list)); + } + + private void validateSetOfSet(Set<Set<Integer>> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, 1); + + Set<Integer> setElt = new HashSet<>(); + for (int i = 0; i < numItems; i++) { + setElt.add(i); + } + + assertTrue(set.contains(setElt)); + } + + private void validateSetOfMap(Set<Map<String, Integer>> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, 1); + + Map<String, Integer> map = new HashMap<>(); + for (int i = 0; i < numItems; i++) { + map.put(Integer.toString(i), i); + } + + assertTrue(set.contains(map)); + } + + private void validateSetOfStruct(Set<SmallStruct> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, numItems); + + for (int i = 0; i < numItems; i++) { + SmallStruct ss = testData.createSmallStruct(i); + assertTrue(set.contains(ss)); + } + } + + private void validateSetOfBinary(Set<ByteBuffer> set, int id, int numItems) { + validateNotNullAndNotEmpty(set, 1); + + for (ByteBuffer b : set) { + ByteBuffer bb = ByteBuffer.wrap(testData.BYTES); + assertEquals(0, bb.compareTo(b)); + } + } + + // ---------------------------------------------------------------------- + // Map validation helpers. + + void validateNotNullAndNotEmpty(Map<?, ?> map, int numItems) { + assertNotNull(map); + assertEquals(numItems, map.size()); + } + + private <V extends Number> void validateMap(Map<V, V> map, Class<V> clasz, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + for (int i = 0; i < numItems; i++) { + if (clasz == Byte.class) { + assertTrue(map.containsKey((byte)i)); + assertEquals((byte) i, map.get((byte) i)); + } else if (clasz == Short.class) { + assertTrue(map.containsKey((short)i)); + assertEquals((short) i, map.get((short) i)); + } else if (clasz == Integer.class) { + assertTrue(map.containsKey(i)); + assertEquals(i, map.get(i)); + } else if (clasz == Long.class) { + assertTrue(map.containsKey((long)i)); + assertEquals((long) i, map.get((long) i)); + } else if (clasz == Double.class) { + assertTrue(map.containsKey((double)i)); + assertEquals((double) i, map.get((double) i)); + } + } + } + + private void validateStringMap(Map<String, String> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + for (int i = 0; i < numItems; i++) { + String key = Integer.toString(i); + assertTrue(map.containsKey(key)); + assertEquals(key, map.get(key)); + } + } + + private void validateEnumMap(Map<TstEnum, TstEnum> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, 1); + + assertTrue(map.containsKey(TstEnum.E_ONE)); + assertEquals(TstEnum.E_ONE, map.get(TstEnum.E_ONE)); + } + + private void validateMapOfList(Map<Integer, List<Integer>> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + List<Integer> list = new ArrayList<>(numItems); + for (int i = 0; i < numItems; i++) { + list.add(i); + } + + for (int i = 0; i < numItems; i++) { + assertTrue(map.containsKey(i)); + assertEquals(list, map.get(i)); + } + } + + private void validateMapOfSet(Map<Integer, Set<Integer>> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + Set<Integer> setElt = new HashSet<>(); + for (int i = 0; i < numItems; i++) { + setElt.add(i); + } + + for (int i = 0; i < numItems; i++) { + assertTrue(map.containsKey(i)); + assertEquals(setElt, map.get(i)); + } + } + + private void validateMapOfMap(Map<Integer, Map<Integer, Integer>> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + Map<Integer, Integer> mapElt = new HashMap<>(); + for (int i = 0; i < numItems; i++) { + mapElt.put(i, i); + } + + for (int i = 0; i < numItems; i++) { + assertTrue(map.containsKey(i)); + assertEquals(mapElt, map.get(i)); + } + } + + private void validateMapOfStruct(Map<SmallStruct, SmallStruct> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + for (int i = 0; i < numItems; i++) { + SmallStruct ss = testData.createSmallStruct(i); + assertTrue(map.containsKey(ss)); + assertEquals(ss, map.get(ss)); + } + } + + private void validateMapOfBinary(Map<Integer, ByteBuffer> map, int id, int numItems) { + validateNotNullAndNotEmpty(map, numItems); + + for (int i = 0; i < numItems; i++) { + ByteBuffer bb = ByteBuffer.wrap(testData.BYTES); + assertTrue(map.containsKey(i)); + assertEquals(0, bb.compareTo(map.get(i))); + } + } +} diff --git a/lib/java/test/org/apache/thrift/partial/EnumCacheTest.java b/lib/java/test/org/apache/thrift/partial/EnumCacheTest.java new file mode 100644 index 000000000..394dcc2e9 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/EnumCacheTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.partial.ExceptionAsserts; + +import org.apache.thrift.TEnum; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Test ThriftCodec serializes and deserializes thrift objects correctly. + */ +public class EnumCacheTest { + + enum TestEnum implements TEnum { + Alice(-1), + Bob(0), + Charlie(1); + + private int value; + + TestEnum(int value) { + this.value = value; + } + + @Override + public int getValue() { + return this.value; + } + } + + static class NotEnum implements TEnum { + + public static final NotEnum Alice = new NotEnum(-11); + public static final NotEnum Bob = new NotEnum(10); + public static final NotEnum Charlie = new NotEnum(11); + + private static final NotEnum[] allValues = { Alice, Bob, Charlie }; + + private int value; + + private NotEnum(int value) { + this.value = value; + } + + public static TEnum[] values() { + return NotEnum.allValues; + } + + @Override + public int getValue() { + return this.value; + } + + @Override + public String toString() { + return String.format("NotEnum : %d", this.value); + } + } + + @Test + public void testArgChecks() { + EnumCache cache = new EnumCache(); + + // Should not throw. + cache.get(TestEnum.class, 0); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'enumClass' must not be null", + () -> cache.get(null, 1)); + } + + @Test + public void testGet() { + EnumCache cache = new EnumCache(); + + assertEquals(TestEnum.Alice, cache.get(TestEnum.class, -1)); + assertEquals(TestEnum.Bob, cache.get(TestEnum.class, 0)); + assertEquals(TestEnum.Charlie, cache.get(TestEnum.class, 1)); + + assertEquals(NotEnum.Alice, cache.get(NotEnum.class, -11)); + assertEquals(NotEnum.Bob, cache.get(NotEnum.class, 10)); + assertEquals(NotEnum.Charlie, cache.get(NotEnum.class, 11)); + } + + @Test + public void testGetInvalid() { + EnumCache cache = new EnumCache(); + + assertNull(cache.get(TestEnum.class, 42)); + assertNull(cache.get(NotEnum.class, 42)); + } +} diff --git a/lib/java/test/org/apache/thrift/partial/ExceptionAsserts.java b/lib/java/test/org/apache/thrift/partial/ExceptionAsserts.java new file mode 100644 index 000000000..239903cf4 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ExceptionAsserts.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public final class ExceptionAsserts { + private ExceptionAsserts() {} + + @FunctionalInterface + public interface CodeThatMayThrow { + void run() throws Exception; + } + + /** + * Asserts that the given code throws an exception of the given type + * and that the exception message contains the given sub-message. + * + * Usage: + * + * ExceptionAsserts.assertThrows( + * IllegalArgumentException.class, + * "'nullArg' must not be null", + * () -> Preconditions.checkNotNull(null, "nullArg")); + * + * Note: JUnit 5 has similar functionality but it will be a long time before + * we move to that framework because of significant differences and lack of + * backward compatibility for some JUnit rules. + */ + public static <E extends Exception> void assertThrows( + Class<E> expectedExceptionClass, + String partialMessage, + CodeThatMayThrow code) { + + Exception thrownException = null; + + try { + code.run(); + } catch (Exception e) { + if (expectedExceptionClass.isAssignableFrom(e.getClass())) { + + thrownException = e; + + if (partialMessage != null) { + String msg = e.getMessage(); + assertNotNull( + String.format("Exception message is null, expected to contain: '%s'", partialMessage), + msg); + assertTrue( + String.format("Exception message '%s' does not contain: '%s'", msg, partialMessage), + msg.contains(partialMessage)); + } + } else { + fail(String.format( + "Expected exception of type %s but got %s", + expectedExceptionClass.getName(), + e.getClass().getName())); + } + } + + if (thrownException == null) { + fail(String.format( + "Expected exception of type %s but got none", + expectedExceptionClass.getName())); + } + } + + public static <E extends Exception> void assertThrows( + Class<E> expectedExceptionClass, + CodeThatMayThrow code) { + assertThrows(expectedExceptionClass, null, code); + } +} diff --git a/lib/java/test/org/apache/thrift/partial/PartialThriftComparerTest.java b/lib/java/test/org/apache/thrift/partial/PartialThriftComparerTest.java new file mode 100644 index 000000000..e1209d733 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/PartialThriftComparerTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.*; + +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TException; +import org.apache.thrift.partial.TestStruct; +import org.apache.thrift.partial.ThriftField; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class PartialThriftComparerTest { + + private ThriftSerDe serde; + private PartialThriftTestData testData = new PartialThriftTestData(); + + public PartialThriftComparerTest() throws TException { + this.serde = new ThriftSerDe(); + } + + @Test + public void testCompareSimple() throws TException, IOException { + TestStruct ts1 = testData.createTestStruct(1, 1); + assertTrue(ts1.isSetI16Field()); + assertTrue(ts1.isSetI32Field()); + + byte[] bytesBinary = serde.serializeBinary(ts1); + byte[] bytesCompact = serde.serializeCompact(ts1); + + List<String> fieldNames = Arrays.asList("i32Field"); + TDeserializer partialBinaryDeser = + new TDeserializer(TestStruct.class, fieldNames, new TBinaryProtocol.Factory()); + TDeserializer partialCompactDeser = + new TDeserializer(TestStruct.class, fieldNames, new TCompactProtocol.Factory()); + + ThriftMetadata.ThriftStruct metadata = partialBinaryDeser.getMetadata(); + PartialThriftComparer comparer = new PartialThriftComparer(metadata); + + StringBuilder sb = new StringBuilder(); + TestStruct ts2 = (TestStruct) partialBinaryDeser.partialDeserializeObject(bytesBinary); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + + ts2 = (TestStruct) partialCompactDeser.partialDeserializeObject(bytesCompact); + if (!comparer.areEqual(ts1, ts2, sb)) { + fail(sb.toString()); + } + } +} diff --git a/lib/java/test/org/apache/thrift/partial/PartialThriftTestData.java b/lib/java/test/org/apache/thrift/partial/PartialThriftTestData.java new file mode 100644 index 000000000..6376075da --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/PartialThriftTestData.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Helpers for creating test data related to partial deserialization. + */ +public class PartialThriftTestData { + + public final byte[] BYTES = new byte[] { 1, 2, 3 }; + + public SmallStruct createSmallStruct(int id) { + return new SmallStruct() + .setByteField((byte) id) + .setI16Field((short) id) + .setI32Field(id) + .setI64Field(id) + .setDoubleField(id) + .setStringField(Integer.toString(id)) + .setEnumField(TstEnum.E_ONE); + } + + public TestStruct createTestStruct(int id, int numItems) { + + TestStruct ts = new TestStruct() + .setByteField((byte) id) + .setI16Field((short) id) + .setI32Field(id) + .setI64Field(id) + .setDoubleField(id) + .setStringField(Integer.toString(id)) + .setEnumField(TstEnum.E_ONE) + .setBinaryField(BYTES) + .setStructField(createSmallStruct(id)); + + initListFields(ts, id, numItems); + initSetFields(ts, id, numItems); + initMapFields(ts, id, numItems); + + return ts; + } + + public void initListFields(TestStruct ts, int id, int numItems) { + List<Byte> byteList = new ArrayList<>(numItems); + List<Short> i16List = new ArrayList<>(numItems); + List<Integer> i32List = new ArrayList<>(numItems); + List<Long> i64List = new ArrayList<>(numItems); + List<Double> doubleList = new ArrayList<>(numItems); + List<String> stringList = new ArrayList<>(numItems); + List<TstEnum> enumList = new ArrayList<>(numItems); + + List<List<Integer>> listList = new ArrayList<>(numItems); + List<Set<Integer>> setList = new ArrayList<>(numItems); + List<Map<String, Integer>> mapList = new ArrayList<>(numItems); + List<SmallStruct> structList = new ArrayList<>(numItems); + List<ByteBuffer> binaryList = new ArrayList<>(numItems); + + for (int i = 0; i < numItems; i++) { + byteList.add((byte) i); + i16List.add((short) i); + i32List.add(i); + i64List.add((long)i); + doubleList.add((double) i); + stringList.add(Integer.toString(i)); + enumList.add(TstEnum.E_ONE); + structList.add(createSmallStruct(i)); + binaryList.add(ByteBuffer.wrap(BYTES)); + + List<Integer> listItem = new ArrayList<>(numItems); + listList.add(listItem); + + Set<Integer> setItem = new HashSet<>(); + setList.add(setItem); + + Map<String, Integer> mapItem = new HashMap<>(); + mapList.add(mapItem); + + for (int j = 0; j < numItems; j++) { + listItem.add(j); + setItem.add(j); + mapItem.put(Integer.toString(j), j); + } + } + + ts.setByteList(byteList) + .setI16List(i16List) + .setI32List(i32List) + .setI64List(i64List) + .setDoubleList(doubleList) + .setStringList(stringList) + .setEnumList(enumList) + .setListList(listList) + .setSetList(setList) + .setMapList(mapList) + .setStructList(structList) + .setBinaryList(binaryList); + } + + public void initSetFields(TestStruct ts, int id, int numItems) { + Set<Byte> byteSet = new HashSet<>(); + Set<Short> i16Set = new HashSet<>(); + Set<Integer> i32Set = new HashSet<>(); + Set<Long> i64Set = new HashSet<>(); + Set<Double> doubleSet = new HashSet<>(); + Set<String> stringSet = new HashSet<>(); + Set<TstEnum> enumSet = new HashSet<>(); + + Set<List<Integer>> listSet = new HashSet<>(); + Set<Set<Integer>> setSet = new HashSet<>(); + Set<Map<String, Integer>> mapSet = new HashSet<>(); + Set<SmallStruct> structSet = new HashSet<>(); + Set<ByteBuffer> binarySet = new HashSet<>(); + + for (int i = 0; i < numItems; i++) { + byteSet.add((byte) i); + i16Set.add((short) i); + i32Set.add(i); + i64Set.add((long)i); + doubleSet.add((double) i); + stringSet.add(Integer.toString(i)); + enumSet.add(TstEnum.E_ONE); + structSet.add(createSmallStruct(i)); + binarySet.add(ByteBuffer.wrap(BYTES)); + + List<Integer> listItem = new ArrayList<>(numItems); + Set<Integer> setItem = new HashSet<>(); + Map<String, Integer> mapItem = new HashMap<>(); + + for (int j = 0; j < numItems; j++) { + setItem.add(j); + listItem.add(j); + mapItem.put(Integer.toString(j), j); + } + + listSet.add(listItem); + setSet.add(setItem); + mapSet.add(mapItem); + } + + ts.setByteSet(byteSet) + .setI16Set(i16Set) + .setI32Set(i32Set) + .setI64Set(i64Set) + .setDoubleSet(doubleSet) + .setStringSet(stringSet) + .setEnumSet(enumSet) + .setListSet(listSet) + .setSetSet(setSet) + .setMapSet(mapSet) + .setStructSet(structSet) + .setBinarySet(binarySet); + } + + public void initMapFields(TestStruct ts, int id, int numItems) { + Map<Byte, Byte> byteMap = new HashMap<>(); + Map<Short, Short> i16Map = new HashMap<>(); + Map<Integer, Integer> i32Map = new HashMap<>(); + Map<Long, Long> i64Map = new HashMap<>(); + Map<Double, Double> doubleMap = new HashMap<>(); + Map<String, String> stringMap = new HashMap<>(); + Map<TstEnum, TstEnum> enumMap = new HashMap<>(); + + Map<Integer, List<Integer>> listMap = new HashMap<>(); + Map<Integer, Set<Integer>> setMap = new HashMap<>(); + Map<Integer, Map<Integer, Integer>> mapMap = new HashMap<>(); + Map<SmallStruct, SmallStruct> structMap = new HashMap<>(); + Map<Integer, ByteBuffer> binaryMap = new HashMap<>(); + + for (int i = 0; i < numItems; i++) { + byteMap.put((byte) i, (byte) i); + i16Map.put((short) i, (short) i); + i32Map.put(i, i); + i64Map.put((long) i, (long) i); + doubleMap.put((double) i, (double) i); + stringMap.put(Integer.toString(i), Integer.toString(i)); + enumMap.put(TstEnum.E_ONE, TstEnum.E_ONE); + structMap.put(createSmallStruct(i), createSmallStruct(i)); + binaryMap.put(i, ByteBuffer.wrap(BYTES)); + + List<Integer> listItem = new ArrayList<>(numItems); + listMap.put(i, listItem); + + Set<Integer> setItem = new HashSet<>(); + setMap.put(i, setItem); + + Map<Integer, Integer> mapItem = new HashMap<>(); + mapMap.put(i, mapItem); + + for (int j = 0; j < numItems; j++) { + listItem.add(j); + setItem.add(j); + mapItem.put(j, j); + } + } + + ts.setByteMap(byteMap) + .setI16Map(i16Map) + .setI32Map(i32Map) + .setI64Map(i64Map) + .setDoubleMap(doubleMap) + .setStringMap(stringMap) + .setEnumMap(enumMap) + .setListMap(listMap) + .setSetMap(setMap) + .setMapMap(mapMap) + .setStructMap(structMap) + .setBinaryMap(binaryMap); + } + + public List<String> allFieldsOfTestStruct() { + return new ArrayList<>( + Arrays.asList( + "byteField", + "i16Field", + "i32Field", + "i64Field", + "doubleField", + "stringField", + "structField.byteField", + "structField.i16Field", + "structField.i32Field", + "structField.i64Field", + "structField.doubleField", + "structField.stringField", + "structField.enumField", + "enumField", + "binaryField", + "byteList", + "i16List", + "i32List", + "i64List", + "doubleList", + "stringList", + "enumList", + "listList", + "setList", + "mapList", + "structList.byteField", + "structList.i16Field", + "structList.i32Field", + "structList.i64Field", + "structList.doubleField", + "structList.stringField", + "structList.enumField", + "binaryList", + "byteSet", + "i16Set", + "i32Set", + "i64Set", + "doubleSet", + "stringSet", + "enumSet", + "listSet", + "setSet", + "mapSet", + "structSet.byteField", + "structSet.i16Field", + "structSet.i32Field", + "structSet.i64Field", + "structSet.doubleField", + "structSet.stringField", + "structSet.enumField", + "binarySet", + "byteMap", + "i16Map", + "i32Map", + "i64Map", + "doubleMap", + "stringMap", + "enumMap", + "listMap", + "setMap", + "mapMap", + "structMap.byteField", + "structMap.i16Field", + "structMap.i32Field", + "structMap.i64Field", + "structMap.doubleField", + "structMap.stringField", + "structMap.enumField", + "binaryMap" + ) + ); + } +} diff --git a/lib/java/test/org/apache/thrift/partial/TFieldDataTest.java b/lib/java/test/org/apache/thrift/partial/TFieldDataTest.java new file mode 100644 index 000000000..0a838e970 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/TFieldDataTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.protocol.TField; +import org.apache.thrift.protocol.TType; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class TFieldDataTest { + + @Test + public void testEncodeStop() { + TField field = new TField("", TType.STOP, (short) 0); + int data = TFieldData.encode(TType.STOP); + + assertEquals(field.type, TFieldData.getType(data)); + assertEquals(field.id, TFieldData.getId(data)); + } + + @Test + public void testEncodeRest() { + for (byte type = 1; type <= 16; type++) { + for (short id = 0; id < Short.MAX_VALUE; id++) { + TField field = new TField("", type, id); + int data = TFieldData.encode(type, id); + + assertEquals(field.type, TFieldData.getType(data)); + assertEquals(field.id, TFieldData.getId(data)); + } + } + } +} diff --git a/lib/java/test/org/apache/thrift/partial/TestData.java b/lib/java/test/org/apache/thrift/partial/TestData.java new file mode 100644 index 000000000..1779346f7 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/TestData.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Frequently used test data items. + */ +public final class TestData { + private TestData() {} + + + // Array data. + public static Object[] nullArray = null; + public static Object[] emptyArray = new Object[0]; + public static Object[] nonEmptyArray = new Object[1]; + + public static byte[] nullByteArray = null; + public static byte[] emptyByteArray = new byte[0]; + public static byte[] nonEmptyByteArray = new byte[1]; + + public static short[] nullShortArray = null; + public static short[] emptyShortArray = new short[0]; + public static short[] nonEmptyShortArray = new short[1]; + + public static int[] nullIntArray = null; + public static int[] emptyIntArray = new int[0]; + public static int[] nonEmptyIntArray = new int[1]; + + public static long[] nullLongArray = null; + public static long[] emptyLongArray = new long[0]; + public static long[] nonEmptyLongArray = new long[1]; + + public static List<Object> nullList = null; + public static List<Object> emptyList = new ArrayList<Object>(); + public static List<Object> validList = Arrays.asList(new Object[1]); +} diff --git a/lib/java/test/org/apache/thrift/partial/ThriftFieldTest.java b/lib/java/test/org/apache/thrift/partial/ThriftFieldTest.java new file mode 100644 index 000000000..a6d5655e4 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ThriftFieldTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.*; + +import org.apache.thrift.partial.ExceptionAsserts; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.*; + +public class ThriftFieldTest { + + @Test + public void testArgChecks() { + ThriftField test; + List<ThriftField> testFields; + + // Should not throw. + test = new ThriftField("foo"); + test = new ThriftField("foo", Arrays.asList(new ThriftField("bar"))); + testFields = ThriftField.fromNames(Arrays.asList("foo")); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'name' must not be null", + () -> new ThriftField(null, Collections.emptyList())); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'fields' must not be null", + () -> new ThriftField("foo", null)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'fieldNames' must not be null", + () -> ThriftField.fromNames(null)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'fieldNames' must have at least one element", + () -> ThriftField.fromNames(Collections.emptyList())); + } + + @Test + public void testFromNames() { + List<String> fieldNames = Arrays.asList( + "f1", + "f2.f21", + "f3.f31.f311", + "f3.f32.f321", + "f3.f32.f322" + ); + + List<ThriftField> testFields = ThriftField.fromNames(fieldNames); + + assertEquals(3, testFields.size()); + ThriftField f1 = testFields.get(0); + ThriftField f2 = testFields.get(1); + ThriftField f3 = testFields.get(2); + assertEquals("f1", f1.name); + assertEquals("f2", f2.name); + assertEquals("f3", f3.name); + + assertEquals(0, f1.fields.size()); + assertEquals(1, f2.fields.size()); + assertEquals(2, f3.fields.size()); + + ThriftField f21 = f2.fields.get(0); + ThriftField f31 = f3.fields.get(0); + ThriftField f32 = f3.fields.get(1); + assertEquals("f21", f21.name); + assertEquals("f31", f31.name); + assertEquals("f32", f32.name); + + assertEquals(0, f21.fields.size()); + assertEquals(1, f31.fields.size()); + assertEquals(2, f32.fields.size()); + + ThriftField f311 = f31.fields.get(0); + ThriftField f321 = f32.fields.get(0); + ThriftField f322 = f32.fields.get(1); + assertEquals("f311", f311.name); + assertEquals("f321", f321.name); + assertEquals("f322", f322.name); + + assertEquals(0, f311.fields.size()); + assertEquals(0, f321.fields.size()); + assertEquals(0, f322.fields.size()); + } + + @Test + public void testEquality() { + List<String> fieldNames = Arrays.asList( + "f1", + "f2.f21", + "f3.f31.f311", + "f3.f32.f321", + "f3.f32.f322" + ); + + List<ThriftField> testFields = ThriftField.fromNames(fieldNames); + List<ThriftField> testFields2 = testFields; + + assertSame(testFields, testFields2); + assertEquals(testFields, testFields2); + + List<ThriftField> testFields3 = ThriftField.fromNames(fieldNames); + assertNotSame(testFields, testFields3); + assertEquals(testFields, testFields3); + assertEquals(testFields.hashCode(), testFields3.hashCode()); + + List<String> fieldNamesDiff = Arrays.asList( + "f1", + "f2.f21", + "f3.f31.f311", + "f3.f32.f323", + "f3.f32.f322" + ); + + List<ThriftField> testFields4 = ThriftField.fromNames(fieldNamesDiff); + assertNotSame(testFields, testFields4); + assertNotEquals(testFields, testFields4); + assertNotEquals(testFields.hashCode(), testFields4.hashCode()); + } +} diff --git a/lib/java/test/org/apache/thrift/partial/ThriftMetadataTest.java b/lib/java/test/org/apache/thrift/partial/ThriftMetadataTest.java new file mode 100644 index 000000000..acc53c8a6 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ThriftMetadataTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.*; + +import org.apache.thrift.partial.TestStruct; +import org.apache.thrift.partial.ThriftField; +import org.apache.thrift.partial.ExceptionAsserts; + +import org.apache.thrift.TBase; +import org.apache.thrift.meta_data.EnumMetaData; +import org.apache.thrift.meta_data.FieldValueMetaData; +import org.apache.thrift.meta_data.ListMetaData; +import org.apache.thrift.meta_data.MapMetaData; +import org.apache.thrift.meta_data.SetMetaData; +import org.apache.thrift.meta_data.StructMetaData; +import org.apache.thrift.protocol.TType; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +public class ThriftMetadataTest { + + private PartialThriftTestData testData = new PartialThriftTestData(); + + @Test + public void testArgChecks() { + // Should not throw. + List<ThriftField> testFields = ThriftField.fromNames(Arrays.asList("byteField")); + ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, testFields); + + // Verify it throws correctly. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'clasz' must not be null", + () -> ThriftMetadata.ThriftStruct.fromFields(null, testFields)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'fields' must not be null", + () -> ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, null)); + } + + @Test + public void testThriftStructOf() { + ThriftMetadata.ThriftStruct testStruct = ThriftMetadata.ThriftStruct.of(TestStruct.class); + assertEquals(45, testStruct.fields.keySet().size()); + validateFieldMetadata(testStruct, 1, "byteField", TType.BYTE); + validateFieldMetadata(testStruct, 2, "i16Field", TType.I16); + validateFieldMetadata(testStruct, 3, "i32Field", TType.I32); + validateFieldMetadata(testStruct, 4, "i64Field", TType.I64); + validateFieldMetadata(testStruct, 5, "doubleField", TType.DOUBLE); + validateFieldMetadata(testStruct, 6, "stringField", TType.STRING); + validateFieldMetadata(testStruct, 7, "enumField", TType.ENUM); + validateFieldMetadata(testStruct, 8, "binaryField", TType.STRING); + + validateListFieldMetadata(testStruct, 10, "byteList", TType.BYTE); + validateSetFieldMetadata(testStruct, 35, "stringSet", TType.STRING); + validateMapFieldMetadata(testStruct, 61, "binaryMap", TType.I32, TType.STRING); + } + + @Test + public void testUnion() { + ThriftMetadata.ThriftStruct structWithUnions = + ThriftMetadata.ThriftStruct.of(StructWithUnions.class); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 1, "intValue"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 2, "smallStruct"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 3, "simpleUnion"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 4, "unionList"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 5, "unionSet"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 6, "keyUnionMap"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 7, "valUnionMap"); + validateBasicFieldMetadata(structWithUnions, StructWithUnions.class, 8, "unionMap"); + + ThriftMetadata.ThriftStructBase smallStructMetadata = + (ThriftMetadata.ThriftStructBase) structWithUnions.fields.get(2); + assertFalse(smallStructMetadata.isUnion()); + + ThriftMetadata.ThriftStructBase simpleUnionMetadata = + (ThriftMetadata.ThriftStructBase) structWithUnions.fields.get(3); + assertTrue(simpleUnionMetadata.isUnion()); + + ThriftMetadata.ThriftList unionListMetadata = + (ThriftMetadata.ThriftList) structWithUnions.fields.get(4); + assertTrue(unionListMetadata.hasUnion()); + + ThriftMetadata.ThriftSet unionSetMetadata = + (ThriftMetadata.ThriftSet) structWithUnions.fields.get(5); + assertTrue(unionSetMetadata.hasUnion()); + + ThriftMetadata.ThriftMap keyUnionMapMetadata = + (ThriftMetadata.ThriftMap) structWithUnions.fields.get(6); + assertTrue(keyUnionMapMetadata.hasUnion()); + + ThriftMetadata.ThriftMap valUnionMapMetadata = + (ThriftMetadata.ThriftMap) structWithUnions.fields.get(7); + assertTrue(valUnionMapMetadata.hasUnion()); + + ThriftMetadata.ThriftMap unionMapMetadata = + (ThriftMetadata.ThriftMap) structWithUnions.fields.get(8); + assertTrue(unionMapMetadata.hasUnion()); + } + + private ThriftMetadata.ThriftObject validateBasicFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + int id, + String fieldName) { + return validateBasicFieldMetadata(testStruct, TestStruct.class, id, fieldName); + } + + private ThriftMetadata.ThriftObject validateBasicFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + Class<? extends TBase> clazz, + int id, + String fieldName) { + + assertNotNull(testStruct); + assertNull(testStruct.parent); + assertEquals(clazz, ((StructMetaData) testStruct.data.valueMetaData).structClass); + assertTrue(testStruct.fields.containsKey(id)); + + ThriftMetadata.ThriftObject fieldMetadata = + (ThriftMetadata.ThriftObject) testStruct.fields.get(id); + assertEquals(testStruct, fieldMetadata.parent); + + assertEquals(id, fieldMetadata.fieldId.getThriftFieldId()); + assertEquals(fieldName, fieldMetadata.fieldId.getFieldName()); + assertEquals(fieldName, fieldMetadata.data.fieldName); + + assertEquals("root ==> " + fieldName, fieldMetadata.toString()); + + return fieldMetadata; + } + + private void validateBasicFieldValueMetadata( + ThriftMetadata.ThriftObject fieldMetadata, + String fieldName, + byte ttype) { + + assertEquals(ttype, fieldMetadata.data.valueMetaData.type); + assertEquals(getMetaDataClassForTType(ttype), fieldMetadata.data.valueMetaData.getClass()); + Class<? extends ThriftMetadata.ThriftObject> fieldMetadataClass = getClassForTType(ttype); + assertEquals(fieldMetadataClass, fieldMetadata.getClass()); + if (fieldMetadataClass == ThriftMetadata.ThriftPrimitive.class) { + ThriftMetadata.ThriftPrimitive primitive + = (ThriftMetadata.ThriftPrimitive) fieldMetadata; + if (fieldName.startsWith("binary") && (ttype == TType.STRING)) { + assertTrue(primitive.isBinary()); + } else { + assertFalse(primitive.isBinary()); + } + } + } + + private void validateFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + int id, + String fieldName, + byte ttype) { + + ThriftMetadata.ThriftObject fieldMetadata = + validateBasicFieldMetadata(testStruct, id, fieldName); + validateBasicFieldValueMetadata(fieldMetadata, fieldName, ttype); + } + + private void validateListFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + int id, + String fieldName, + byte ttype) { + + ThriftMetadata.ThriftObject fieldMetadata = + validateBasicFieldMetadata(testStruct, id, fieldName); + validateBasicFieldValueMetadata(fieldMetadata, fieldName, TType.LIST); + + ThriftMetadata.ThriftList thriftList = (ThriftMetadata.ThriftList) fieldMetadata; + ThriftMetadata.ThriftObject elementMetadata = thriftList.elementData; + validateBasicFieldValueMetadata(elementMetadata, fieldName + "_element", ttype); + } + + private void validateSetFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + int id, + String fieldName, + byte ttype) { + + ThriftMetadata.ThriftObject fieldMetadata = + validateBasicFieldMetadata(testStruct, id, fieldName); + validateBasicFieldValueMetadata(fieldMetadata, fieldName, TType.SET); + + ThriftMetadata.ThriftSet thriftSet = (ThriftMetadata.ThriftSet) fieldMetadata; + ThriftMetadata.ThriftObject elementMetadata = thriftSet.elementData; + validateBasicFieldValueMetadata(elementMetadata, fieldName + "_element", ttype); + } + + private void validateMapFieldMetadata( + ThriftMetadata.ThriftStruct testStruct, + int id, + String fieldName, + byte keyType, + byte valueType) { + + ThriftMetadata.ThriftObject fieldMetadata = + validateBasicFieldMetadata(testStruct, id, fieldName); + validateBasicFieldValueMetadata(fieldMetadata, fieldName, TType.MAP); + + ThriftMetadata.ThriftMap thriftMap = (ThriftMetadata.ThriftMap) fieldMetadata; + ThriftMetadata.ThriftObject keyMetadata = thriftMap.keyData; + ThriftMetadata.ThriftObject valueMetadata = thriftMap.valueData; + validateBasicFieldValueMetadata(keyMetadata, fieldName + "_key", keyType); + validateBasicFieldValueMetadata(valueMetadata, fieldName + "_value", valueType); + } + + private Class<? extends FieldValueMetaData> getMetaDataClassForTType(byte ttype) { + switch (ttype) { + case TType.STRUCT: + return StructMetaData.class; + + case TType.LIST: + return ListMetaData.class; + + case TType.MAP: + return MapMetaData.class; + + case TType.SET: + return SetMetaData.class; + + case TType.ENUM: + return EnumMetaData.class; + + case TType.BOOL: + case TType.BYTE: + case TType.I16: + case TType.I32: + case TType.I64: + case TType.DOUBLE: + case TType.STRING: + return FieldValueMetaData.class; + + default: + throw ThriftMetadata.unsupportedFieldTypeException(ttype); + } + } + + private Class<? extends ThriftMetadata.ThriftObject> getClassForTType(byte ttype) { + switch (ttype) { + case TType.STRUCT: + return ThriftMetadata.ThriftStruct.class; + + case TType.LIST: + return ThriftMetadata.ThriftList.class; + + case TType.MAP: + return ThriftMetadata.ThriftMap.class; + + case TType.SET: + return ThriftMetadata.ThriftSet.class; + + case TType.ENUM: + return ThriftMetadata.ThriftEnum.class; + + case TType.BOOL: + case TType.BYTE: + case TType.I16: + case TType.I32: + case TType.I64: + case TType.DOUBLE: + case TType.STRING: + return ThriftMetadata.ThriftPrimitive.class; + + default: + throw ThriftMetadata.unsupportedFieldTypeException(ttype); + } + } +} diff --git a/lib/java/test/org/apache/thrift/partial/ThriftSerDe.java b/lib/java/test/org/apache/thrift/partial/ThriftSerDe.java new file mode 100644 index 000000000..361c32c60 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ThriftSerDe.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.TBase; +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TException; +import org.apache.thrift.TSerializer; +import org.apache.thrift.protocol.TBinaryProtocol; +import org.apache.thrift.protocol.TCompactProtocol; + +public class ThriftSerDe { + private TSerializer binarySerializer; + private TSerializer compactSerializer; + private TDeserializer binaryDeserializer; + private TDeserializer compactDeserializer; + + public ThriftSerDe() throws TException { + this.binarySerializer = new TSerializer(new TBinaryProtocol.Factory()); + this.compactSerializer = new TSerializer(new TCompactProtocol.Factory()); + this.binaryDeserializer = new TDeserializer(new TBinaryProtocol.Factory()); + this.compactDeserializer = new TDeserializer(new TCompactProtocol.Factory()); + } + + public byte[] serializeBinary(TBase obj) throws TException { + return binarySerializer.serialize(obj); + } + + public byte[] serializeCompact(TBase obj) throws TException { + return compactSerializer.serialize(obj); + } + + public <T extends TBase> T deserializeBinary(byte[] bytes, Class<T> clazz) throws TException { + T instance = this.newInstance(clazz); + binaryDeserializer.deserialize(instance, bytes); + return clazz.cast(instance); + } + + public <T extends TBase> T deserializeCompact(byte[] bytes, Class<T> clazz) throws TException { + T instance = this.newInstance(clazz); + compactDeserializer.deserialize(instance, bytes); + return clazz.cast(instance); + } + + private <T extends TBase> T newInstance(Class<T> clazz) { + T instance = null; + try { + instance = clazz.newInstance(); + } catch (InstantiationException e) { + } catch (IllegalAccessException e) { + } + return clazz.cast(instance); + } +} diff --git a/lib/java/test/org/apache/thrift/partial/ThriftStructProcessorTest.java b/lib/java/test/org/apache/thrift/partial/ThriftStructProcessorTest.java new file mode 100644 index 000000000..d4ab92509 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ThriftStructProcessorTest.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import static org.junit.Assert.*; + +import org.apache.thrift.partial.TestStruct; +import org.apache.thrift.partial.ThriftField; +import org.apache.thrift.partial.ThriftMetadata; +import org.apache.thrift.partial.TstEnum; + +import org.apache.thrift.TBase; +import org.apache.thrift.TException; +import org.apache.thrift.TFieldIdEnum; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class ThriftStructProcessorTest { + + private PartialThriftTestData testData = new PartialThriftTestData(); + + @Test + public void testStruct() throws TException { + List<ThriftField> fields = ThriftField.fromNames(Arrays.asList("i32Field")); + ThriftMetadata.ThriftStruct metadata = + ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, fields); + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewStruct(metadata); + assertNotNull(instance); + assertTrue(instance instanceof TBase); + assertTrue(instance instanceof TestStruct); + + Object instance2 = processor.prepareStruct(instance); + assertSame(instance, instance2); + } + + @Test + public void testList() throws TException { + final int numItems = 10; + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewList(numItems); + assertNotNull(instance); + assertTrue(instance instanceof Object[]); + + Object[] items = (Object[]) instance; + for (int i = 0; i < numItems; i++) { + assertNull(items[i]); + processor.setListElement(instance, i, Integer.valueOf(i)); + assertEquals(i, items[i]); + } + + assertTrue(processor.prepareList(instance) instanceof List<?>); + } + + @Test + public void testMap() throws TException { + final int numItems = 10; + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewMap(numItems); + assertNotNull(instance); + assertTrue(instance instanceof Map<?, ?>); + + Map<Object, Object> items = (Map<Object, Object>) instance; + int ignoredIndex = -1; + for (int i = 0; i < numItems; i++) { + assertNull(items.get(i)); + processor.setMapElement(instance, ignoredIndex, Integer.valueOf(i), Integer.valueOf(i)); + assertEquals(i, items.get(i)); + } + + assertTrue(processor.prepareMap(instance) instanceof Map<?, ?>); + } + + @Test + public void testSet() throws TException { + final int numItems = 10; + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewSet(numItems); + assertNotNull(instance); + assertTrue(instance instanceof HashSet<?>); + + Set<?> items = (HashSet<?>) instance; + int ignoredIndex = -1; + + for (int i = 0; i < numItems; i++) { + assertFalse(items.contains(i)); + processor.setSetElement(instance, ignoredIndex, Integer.valueOf(i)); + assertTrue(items.contains(i)); + } + + assertTrue(processor.prepareSet(instance) instanceof Set<?>); + } + + @Test + public void testPrepareEnum() throws TException { + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.prepareEnum(TstEnum.class, 1); + assertNotNull(instance); + assertEquals(TstEnum.E_ONE, instance); + + instance = processor.prepareEnum(TstEnum.class, 2); + assertNotNull(instance); + assertEquals(TstEnum.E_TWO, instance); + } + + @Test + public void testPrepareString() throws TException { + ThriftStructProcessor processor = new ThriftStructProcessor(); + ByteBuffer emptyBuffer = ByteBuffer.wrap(new byte[0]); + Object instance = processor.prepareString(emptyBuffer); + assertNotNull(instance); + assertTrue(instance instanceof String); + assertEquals("", instance); + + String value = "Hello world!"; + ByteBuffer buffer = ByteBuffer.wrap(value.getBytes(StandardCharsets.UTF_8)); + instance = processor.prepareString(buffer); + assertNotNull(instance); + assertTrue(instance instanceof String); + assertEquals(value, instance); + } + + @Test + public void testPrepareBinary() throws TException { + ThriftStructProcessor processor = new ThriftStructProcessor(); + ByteBuffer emptyBuffer = ByteBuffer.wrap(new byte[0]); + Object instance = processor.prepareBinary(emptyBuffer); + assertNotNull(instance); + assertTrue(instance instanceof ByteBuffer); + assertSame(emptyBuffer, instance); + } + + @Test + public void testStructPrimitiveFields() throws TException { + List<ThriftField> fields = ThriftField.fromNames( + Arrays.asList( + "byteField", + "i16Field", + "i32Field", + "i64Field", + "doubleField", + "stringField", + + "enumField", + "binaryField" + )); + + ThriftMetadata.ThriftStruct metadata = + ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, fields); + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewStruct(metadata); + assertNotNull(instance); + assertTrue(instance instanceof TBase); + assertTrue(instance instanceof TestStruct); + + TestStruct struct = (TestStruct) instance; + + // byte + TFieldIdEnum fieldId = findFieldId(metadata, "byteField"); + assertNull(getFieldValue(struct, fieldId)); + processor.setByte(struct, fieldId, (byte) 42); + assertEquals(42, struct.getByteField()); + + // short + fieldId = findFieldId(metadata, "i16Field"); + assertNull(getFieldValue(struct, fieldId)); + processor.setInt16(struct, fieldId, (short) 42); + assertEquals(42, struct.getI16Field()); + + // int + fieldId = findFieldId(metadata, "i32Field"); + assertNull(getFieldValue(struct, fieldId)); + processor.setInt32(struct, fieldId, 42); + assertEquals(42, struct.getI32Field()); + + // long + fieldId = findFieldId(metadata, "i64Field"); + assertNull(getFieldValue(struct, fieldId)); + processor.setInt64(struct, fieldId, 42L); + assertEquals(42, struct.getI64Field()); + + // binary + fieldId = findFieldId(metadata, "binaryField"); + assertNull(getFieldValue(struct, fieldId)); + byte[] noBytes = new byte[0]; + ByteBuffer emptyBuffer = ByteBuffer.wrap(noBytes); + processor.setBinary(struct, fieldId, emptyBuffer); + assertArrayEquals(noBytes, struct.getBinaryField()); + + // string + fieldId = findFieldId(metadata, "stringField"); + assertNull(getFieldValue(struct, fieldId)); + String value = "Hello world!"; + ByteBuffer buffer = ByteBuffer.wrap(value.getBytes(StandardCharsets.UTF_8)); + processor.setString(struct, fieldId, buffer); + assertEquals(value, struct.getStringField()); + + // enum + fieldId = findFieldId(metadata, "enumField"); + assertNull(getFieldValue(struct, fieldId)); + TstEnum e1 = TstEnum.E_ONE; + processor.setEnumField(struct, fieldId, e1); + assertEquals(TstEnum.E_ONE, struct.getEnumField()); + } + + @Test + public void testStructContainerFields() throws TException { + List<ThriftField> fields = ThriftField.fromNames( + Arrays.asList( + // List field + "i32List", + + // Set field + "stringSet", + + // Map field + "stringMap", + + // Struct field + "structField" + )); + + ThriftMetadata.ThriftStruct metadata = + ThriftMetadata.ThriftStruct.fromFields(TestStruct.class, fields); + ThriftStructProcessor processor = new ThriftStructProcessor(); + Object instance = processor.createNewStruct(metadata); + assertNotNull(instance); + assertTrue(instance instanceof TBase); + assertTrue(instance instanceof TestStruct); + + TestStruct struct = (TestStruct) instance; + + // list + TFieldIdEnum fieldId = findFieldId(metadata, "i32List"); + assertNull(getFieldValue(struct, fieldId)); + Integer[] ints = new Integer[] { 1, 2, 3 }; + List<Integer> intList = Arrays.asList(ints); + processor.setListField(struct, fieldId, intList); + assertArrayEquals(ints, struct.getI32List().toArray()); + + // set + fieldId = findFieldId(metadata, "stringSet"); + assertNull(getFieldValue(struct, fieldId)); + String[] strings = new String[] { "Hello", "World!" }; + Set<String> stringSet = new HashSet<>(Arrays.asList(strings)); + processor.setSetField(struct, fieldId, stringSet); + assertEquals(stringSet, struct.getStringSet()); + + // map + fieldId = findFieldId(metadata, "stringMap"); + assertNull(getFieldValue(struct, fieldId)); + Map<String, String> stringMap = new HashMap<>(); + stringMap.put("foo", "bar"); + stringMap.put("Hello", "World!"); + processor.setMapField(struct, fieldId, stringMap); + assertEquals(stringMap, struct.getStringMap()); + + // struct + fieldId = findFieldId(metadata, "structField"); + assertNull(getFieldValue(struct, fieldId)); + SmallStruct smallStruct = new SmallStruct(); + smallStruct.setI32Field(42); + SmallStruct smallStruct2 = new SmallStruct(); + smallStruct2.setI32Field(42); + processor.setStructField(struct, fieldId, smallStruct); + assertEquals(smallStruct2, struct.getStructField()); + } + + private TFieldIdEnum findFieldId(ThriftMetadata.ThriftStruct metadata, String fieldName) { + Collection<ThriftMetadata.ThriftObject> fields = metadata.fields.values(); + for (ThriftMetadata.ThriftObject field : fields) { + if (fieldName.equalsIgnoreCase(field.fieldId.getFieldName())) { + return field.fieldId; + } + } + + fail("Field not found: " + fieldName); + return null; + } + + private Object getFieldValue(TBase struct, TFieldIdEnum fieldId) { + TFieldIdEnum fieldRef = struct.fieldForId(fieldId.getThriftFieldId()); + if (struct.isSet(fieldRef)) { + return struct.getFieldValue(fieldRef); + } else { + return null; + } + } +} diff --git a/lib/java/test/org/apache/thrift/partial/ValidateTest.java b/lib/java/test/org/apache/thrift/partial/ValidateTest.java new file mode 100644 index 000000000..9d96844f8 --- /dev/null +++ b/lib/java/test/org/apache/thrift/partial/ValidateTest.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.thrift.partial; + +import org.apache.thrift.partial.ExceptionAsserts; +import org.apache.thrift.partial.TestData; + +import org.junit.Test; +import org.junit.runners.JUnit4; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; + +public class ValidateTest { + @Test + public void testCheckNotNull() { + String nonNullArg = "nonNullArg"; + String nullArg = null; + + // Should not throw. + Validate.checkNotNull(nonNullArg, "nonNullArg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'nullArg' must not be null", + () -> Validate.checkNotNull(nullArg, "nullArg")); + } + + @Test + public void testCheckPositiveInteger() { + int positiveArg = 1; + int zero = 0; + int negativeArg = -1; + + // Should not throw. + Validate.checkPositiveInteger(positiveArg, "positiveArg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'negativeArg' must be a positive integer", + () -> Validate.checkPositiveInteger(negativeArg, "negativeArg")); + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'zero' must be a positive integer", + () -> Validate.checkPositiveInteger(zero, "zero")); + } + + @Test + public void testCheckNotNegative() { + int positiveArg = 1; + int zero = 0; + int negativeArg = -1; + + // Should not throw. + Validate.checkNotNegative(zero, "zeroArg"); + Validate.checkNotNegative(positiveArg, "positiveArg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'negativeArg' must not be negative", + () -> Validate.checkNotNegative(negativeArg, "negativeArg")); + } + + @Test + public void testCheckRequired() { + // Should not throw. + Validate.checkRequired(true, "arg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' is required", + () -> Validate.checkRequired(false, "arg")); + } + + @Test + public void testCheckValid() { + // Should not throw. + Validate.checkValid(true, "arg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' is invalid", + () -> Validate.checkValid(false, "arg")); + } + + @Test + public void testCheckValidWithValues() { + String validValues = "foo, bar"; + + // Should not throw. + Validate.checkValid(true, "arg", validValues); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' is invalid. Valid values are: foo, bar", + () -> Validate.checkValid(false, "arg", validValues)); + } + + @Test + public void testCheckNotNullAndNotEmpty() { + // Should not throw. + Validate.checkNotNullAndNotEmpty(TestData.nonEmptyArray, "array"); + Validate.checkNotNullAndNotEmpty(TestData.nonEmptyByteArray, "array"); + Validate.checkNotNullAndNotEmpty(TestData.nonEmptyShortArray, "array"); + Validate.checkNotNullAndNotEmpty(TestData.nonEmptyIntArray, "array"); + Validate.checkNotNullAndNotEmpty(TestData.nonEmptyLongArray, "array"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'string' must not be empty", + () -> Validate.checkNotNullAndNotEmpty("", "string")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullByteArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyByteArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullShortArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyShortArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullIntArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyIntArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullLongArray, "array")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'array' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyLongArray, "array")); + } + + @Test + public void testCheckListNotNullAndNotEmpty() { + // Should not throw. + Validate.checkNotNullAndNotEmpty(TestData.validList, "list"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'list' must not be null", + () -> Validate.checkNotNullAndNotEmpty(TestData.nullList, "list")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'list' must have at least one element", + () -> Validate.checkNotNullAndNotEmpty(TestData.emptyList, "list")); + } + + @Test + public void testCheckNotNullAndNumberOfElements() { + // Should not throw. + Validate.checkNotNullAndNumberOfElements(Arrays.asList(1, 2, 3), 3, "arg"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' must not be null", + () -> Validate.checkNotNullAndNumberOfElements(null, 3, "arg") + ); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "Number of elements in 'arg' must be exactly 3, 2 given.", + () -> Validate.checkNotNullAndNumberOfElements(Arrays.asList(1, 2), 3, "arg") + ); + } + + @Test + public void testCheckValuesEqual() { + // Should not throw. + Validate.checkValuesEqual(1, "arg1", 1, "arg2"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg1' (1) must equal 'arg2' (2)", + () -> Validate.checkValuesEqual(1, "arg1", 2, "arg2")); + } + + @Test + public void testCheckIntegerMultiple() { + // Should not throw. + Validate.checkIntegerMultiple(10, "arg1", 5, "arg2"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg1' (10) must be an integer multiple of 'arg2' (3)", + () -> Validate.checkIntegerMultiple(10, "arg1", 3, "arg2")); + } + + @Test + public void testCheckGreater() { + // Should not throw. + Validate.checkGreater(10, "arg1", 5, "arg2"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg1' (5) must be greater than 'arg2' (10)", + () -> Validate.checkGreater(5, "arg1", 10, "arg2")); + } + + @Test + public void testCheckGreaterOrEqual() { + // Should not throw. + Validate.checkGreaterOrEqual(10, "arg1", 5, "arg2"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg1' (5) must be greater than or equal to 'arg2' (10)", + () -> Validate.checkGreaterOrEqual(5, "arg1", 10, "arg2")); + } + + @Test + public void testCheckWithinRange() { + // Should not throw. + Validate.checkWithinRange(10, "arg", 5, 15); + Validate.checkWithinRange(10.0, "arg", 5.0, 15.0); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' (5) must be within the range [10, 20]", + () -> Validate.checkWithinRange(5, "arg", 10, 20)); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'arg' (5.0) must be within the range [10.0, 20.0]", + () -> Validate.checkWithinRange(5.0, "arg", 10.0, 20.0)); + } + + @Test + public void testCheckPathExists() throws IOException { + Path tempFile = Files.createTempFile("foo", "bar"); + Path tempDir = tempFile.getParent(); + Path notFound = Paths.get("<not-found>"); + + // Should not throw. + Validate.checkPathExists(tempFile, "tempFile"); + Validate.checkPathExists(tempDir, "tempDir"); + + // Verify it throws. + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "'nullArg' must not be null", + () -> Validate.checkPathExists(null, "nullArg")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "Path notFound (<not-found>) does not exist", + () -> Validate.checkPathExists(notFound, "notFound")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "must point to a directory", + () -> Validate.checkPathExistsAsDir(tempFile, "tempFile")); + + ExceptionAsserts.assertThrows( + IllegalArgumentException.class, + "must point to a file", + () -> Validate.checkPathExistsAsFile(tempDir, "tempDir")); + } +} |