Skip to content

Fix decryption when client is using AutoEncryptionSettings #4439

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4432-SNAPSHOT</version>
<packaging>pom</packaging>

<name>Spring Data MongoDB</name>
Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-benchmarks/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4432-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb-distribution/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4432-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
2 changes: 1 addition & 1 deletion spring-data-mongodb/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<parent>
<groupId>org.springframework.data</groupId>
<artifactId>spring-data-mongodb-parent</artifactId>
<version>4.2.0-SNAPSHOT</version>
<version>4.2.x-4432-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -25,6 +26,7 @@
import org.bson.BsonDocument;
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.conversions.Bson;
import org.bson.types.Binary;
import org.springframework.core.CollectionFactory;
import org.springframework.data.mongodb.core.convert.MongoConversionContext;
Expand Down Expand Up @@ -63,7 +65,7 @@ public MongoEncryptionConverter(Encryption<BsonValue, BsonBinary> encryption, En
public Object read(Object value, MongoConversionContext context) {

Object decrypted = EncryptingConverter.super.read(value, context);
return decrypted instanceof BsonValue ? BsonUtils.toJavaType((BsonValue) decrypted) : decrypted;
return decrypted instanceof BsonValue bsonValue ? BsonUtils.toJavaType(bsonValue) : decrypted;
}

@Override
Expand All @@ -87,36 +89,56 @@ public Object decrypt(Object encryptedValue, EncryptionContext context) {
}

MongoPersistentProperty persistentProperty = getProperty(context);

if (getProperty(context).isCollectionLike() && decryptedValue instanceof Iterable<?> iterable) {

int size = iterable instanceof Collection<?> c ? c.size() : 10;

if (!persistentProperty.isEntity()) {
Collection<Object> collection = CollectionFactory.createCollection(persistentProperty.getType(), size);
iterable.forEach(it -> collection.add(BsonUtils.toJavaType((BsonValue) it)));
iterable.forEach(it -> {
if(it instanceof BsonValue bsonValue) {
collection.add(BsonUtils.toJavaType(bsonValue));
} else {
collection.add(context.read(it, persistentProperty.getActualType()));
}
});

return collection;
} else {
Collection<Object> collection = CollectionFactory.createCollection(persistentProperty.getType(), size);
iterable.forEach(it -> {
collection.add(context.read(BsonUtils.toJavaType((BsonValue) it), persistentProperty.getActualType()));
if(it instanceof BsonValue bsonValue) {
collection.add(context.read(BsonUtils.toJavaType(bsonValue), persistentProperty.getActualType()));
} else {
collection.add(context.read(it, persistentProperty.getActualType()));
}
});
return collection;
}
}

if (!persistentProperty.isEntity() && decryptedValue instanceof BsonValue bsonValue) {
if (persistentProperty.isMap() && persistentProperty.getType() != Document.class) {
return new LinkedHashMap<>((Document) BsonUtils.toJavaType(bsonValue));

if (!persistentProperty.isEntity() && persistentProperty.isMap()) {
if(persistentProperty.getType() != Document.class) {
if(decryptedValue instanceof BsonValue bsonValue) {
return new LinkedHashMap<>((Document) BsonUtils.toJavaType(bsonValue));
}
if(decryptedValue instanceof Document document) {
return new LinkedHashMap<>(document);
}
if(decryptedValue instanceof Map map) {
return map;
}
}
return BsonUtils.toJavaType(bsonValue);
}

if (persistentProperty.isEntity() && decryptedValue instanceof BsonDocument bsonDocument) {
return context.read(BsonUtils.toJavaType(bsonDocument), persistentProperty.getTypeInformation().getType());
}

if (persistentProperty.isEntity() && decryptedValue instanceof Document document) {
return context.read(document, persistentProperty.getTypeInformation().getType());
}

return decryptedValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,33 @@
*/
package org.springframework.data.mongodb.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import java.util.function.Function;
import java.util.stream.StreamSupport;

import org.bson.*;
import org.bson.codecs.Codec;
import org.bson.codecs.DocumentCodec;
import org.bson.codecs.EncoderContext;
import org.bson.codecs.configuration.CodecConfigurationException;
import org.bson.codecs.configuration.CodecRegistry;
import org.bson.conversions.Bson;
import org.bson.json.JsonParseException;
import org.bson.types.Binary;
import org.bson.types.Decimal128;
import org.bson.types.ObjectId;
import org.springframework.core.convert.converter.Converter;
import org.springframework.data.mongodb.CodecRegistryProvider;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.CollectionUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -103,7 +110,7 @@ public static Map<String, Object> asMap(@Nullable Bson bson, CodecRegistry codec
return dbo.toMap();
}

return new Document((Map) bson.toBsonDocument(Document.class, codecRegistry));
return new Document(bson.toBsonDocument(Document.class, codecRegistry));
}

/**
Expand Down Expand Up @@ -321,6 +328,20 @@ public static Object toJavaType(BsonValue value) {
* @since 3.0
*/
public static BsonValue simpleToBsonValue(Object source) {
return simpleToBsonValue(source, MongoClientSettings.getDefaultCodecRegistry());
}

/**
* Convert a given simple value (eg. {@link String}, {@link Long}) to its corresponding {@link BsonValue}.
*
* @param source must not be {@literal null}.
* @param codecRegistry The {@link CodecRegistry} used as a fallback to convert types using native {@link Codec}. Must
* not be {@literal null}.
* @return the corresponding {@link BsonValue} representation.
* @throws IllegalArgumentException if {@literal source} does not correspond to a {@link BsonValue} type.
* @since 4.2
*/
public static BsonValue simpleToBsonValue(Object source, CodecRegistry codecRegistry) {

if (source instanceof BsonValue bsonValue) {
return bsonValue;
Expand Down Expand Up @@ -358,12 +379,30 @@ public static BsonValue simpleToBsonValue(Object source) {
return new BsonDouble(floatValue);
}

if(source instanceof Binary binary) {
if (source instanceof Binary binary) {
return new BsonBinary(binary.getType(), binary.getData());
}

throw new IllegalArgumentException(String.format("Unable to convert %s (%s) to BsonValue.", source,
source != null ? source.getClass().getName() : "null"));
if (source instanceof Date date) {
new BsonDateTime(date.getTime());
}

try {

Object value = source;
if (ClassUtils.isPrimitiveArray(source.getClass())) {
value = CollectionUtils.arrayToList(source);
}

Codec codec = codecRegistry.get(value.getClass());
BsonCapturingWriter writer = new BsonCapturingWriter(value.getClass());
codec.encode(writer, value,
ObjectUtils.isArray(value) || value instanceof Collection<?> ? EncoderContext.builder().build() : null);
return writer.getCapturedValue();
} catch (CodecConfigurationException e) {
throw new IllegalArgumentException(
String.format("Unable to convert %s to BsonValue.", source != null ? source.getClass().getName() : "null"));
}
}

/**
Expand Down Expand Up @@ -669,7 +708,7 @@ private static String toJson(@Nullable Object value) {

if (value instanceof Collection<?> collection) {
return toString(collection);
} else if (value instanceof Map<?,?> map) {
} else if (value instanceof Map<?, ?> map) {
return toString(map);
} else if (ObjectUtils.isArray(value)) {
return toString(Arrays.asList(ObjectUtils.toObjectArray(value)));
Expand Down Expand Up @@ -708,4 +747,162 @@ private static <T> String iterableToDelimitedString(Iterable<T> source, String p

return joiner.toString();
}

private static class BsonCapturingWriter extends AbstractBsonWriter {

List<BsonValue> values = new ArrayList<>(0);

public BsonCapturingWriter(Class<?> type) {
super(new BsonWriterSettings());
if (ClassUtils.isAssignable(Map.class, type)) {
setContext(new Context(null, BsonContextType.DOCUMENT));
} else if (ClassUtils.isAssignable(List.class, type) || type.isArray()) {
setContext(new Context(null, BsonContextType.ARRAY));
} else {
setContext(new Context(null, BsonContextType.DOCUMENT));
}
}

BsonValue getCapturedValue() {

if (values.isEmpty()) {
return null;
}
if (!getContext().getContextType().equals(BsonContextType.ARRAY)) {
return values.get(0);
}

return new BsonArray(values);
}

@Override
protected void doWriteStartDocument() {

}

@Override
protected void doWriteEndDocument() {

}

@Override
public void writeStartArray() {
setState(State.VALUE);
}

@Override
public void writeEndArray() {
setState(State.NAME);
}

@Override
protected void doWriteStartArray() {

}

@Override
protected void doWriteEndArray() {

}

@Override
protected void doWriteBinaryData(BsonBinary value) {
values.add(value);
}

@Override
protected void doWriteBoolean(boolean value) {
values.add(BsonBoolean.valueOf(value));
}

@Override
protected void doWriteDateTime(long value) {
values.add(new BsonDateTime(value));
}

@Override
protected void doWriteDBPointer(BsonDbPointer value) {
values.add(value);
}

@Override
protected void doWriteDouble(double value) {
values.add(new BsonDouble(value));
}

@Override
protected void doWriteInt32(int value) {
values.add(new BsonInt32(value));
}

@Override
protected void doWriteInt64(long value) {
values.add(new BsonInt64(value));
}

@Override
protected void doWriteDecimal128(Decimal128 value) {
values.add(new BsonDecimal128(value));
}

@Override
protected void doWriteJavaScript(String value) {
values.add(new BsonJavaScript(value));
}

@Override
protected void doWriteJavaScriptWithScope(String value) {
values.add(new BsonJavaScriptWithScope(value, null));
}

@Override
protected void doWriteMaxKey() {

}

@Override
protected void doWriteMinKey() {

}

@Override
protected void doWriteNull() {
values.add(new BsonNull());
}

@Override
protected void doWriteObjectId(ObjectId value) {
values.add(new BsonObjectId(value));
}

@Override
protected void doWriteRegularExpression(BsonRegularExpression value) {
values.add(value);
}

@Override
protected void doWriteString(String value) {
values.add(new BsonString(value));
}

@Override
protected void doWriteSymbol(String value) {
values.add(new BsonSymbol(value));
}

@Override
protected void doWriteTimestamp(BsonTimestamp value) {
values.add(value);
}

@Override
protected void doWriteUndefined() {
values.add(new BsonUndefined());
}

@Override
public void flush() {
values.clear();
}
}
}
Loading