Skip to content

WIP Reduce fetch phase heap consumption #127386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -52,6 +53,7 @@
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportActionProxy;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
Expand Down Expand Up @@ -487,6 +489,7 @@ public static void registerRequestHandler(TransportService transportService, Sea
(request, channel, task) -> searchService.executeFetchPhase(
request,
(SearchShardTask) task,
maybeGetNetworkBuffer(transportService, channel),
new ChannelActionListener<>(channel)
)
);
Expand All @@ -503,7 +506,12 @@ public static void registerRequestHandler(TransportService transportService, Sea
TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new);

final TransportRequestHandler<ShardFetchRequest> shardFetchRequestHandler = (request, channel, task) -> searchService
.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
.executeFetchPhase(
request,
(SearchShardTask) task,
maybeGetNetworkBuffer(transportService, channel),
new ChannelActionListener<>(channel)
);
transportService.registerRequestHandler(
FETCH_ID_SCROLL_ACTION_NAME,
EsExecutors.DIRECT_EXECUTOR_SERVICE,
Expand Down Expand Up @@ -531,6 +539,12 @@ public static void registerRequestHandler(TransportService transportService, Sea
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NODE_NAME, true, CanMatchNodeResponse::new);
}

/**
 * Returns a recycled network buffer for pre-serializing the fetch response, or {@code null}
 * when the response must go through regular transport-layer serialization instead.
 *
 * @param transportService source of pooled network buffers
 * @param channel          the channel the response will be sent on
 * @return a fresh network buffer, or {@code null} for direct (local) channels and for
 *         channels that apply compression (the transport layer must serialize those itself)
 */
private static RecyclerBytesStreamOutput maybeGetNetworkBuffer(TransportService transportService, TransportChannel channel) {
    // Local (direct) responses never hit the wire, and compressed channels need the
    // transport layer to serialize+compress — neither can take a pre-serialized buffer.
    if (TransportService.DIRECT_RESPONSE_PROFILE.equals(channel.getProfileName())) {
        return null;
    }
    if (channel.compressionScheme() != null) {
        return null;
    }
    return transportService.newNetworkBytesStream();
}

private static Executor buildFreeContextExecutor(TransportService transportService) {
final ThrottledTaskRunner throttledTaskRunner = new ThrottledTaskRunner(
"free_context",
Expand Down
42 changes: 33 additions & 9 deletions server/src/main/java/org/elasticsearch/search/SearchHits.java
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,9 @@ public boolean isPooled() {
@Override
public void writeTo(StreamOutput out) throws IOException {
assert hasReferences();
final boolean hasTotalHits = totalHits != null;
out.writeBoolean(hasTotalHits);
if (hasTotalHits) {
Lucene.writeTotalHits(out, totalHits);
}
out.writeFloat(maxScore);
writeHeader(out);
out.writeArray(hits);
out.writeOptional(Lucene::writeSortFieldArray, sortFields);
out.writeOptionalString(collapseField);
out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
writeFooter(out);
}

/**
Expand Down Expand Up @@ -260,6 +253,37 @@ private void deallocate() {
}
}

/**
 * Serializes these hits to {@code out} while transferring ownership to the caller:
 * this instance is released up front and each hit is released immediately after it
 * has been written, so peak heap holds at most one serialized hit alongside the buffer.
 * After this call the instance must not be used again.
 */
public void writeAndRelease(StreamOutput out) throws IOException {
    // Drop our own reference first; the assert documents that the caller must hold
    // the last reference, so this decRef is the final one for the container.
    boolean released = refCounted.decRef();
    assert released;
    writeHeader(out);
    var hits = this.hits;
    out.writeVInt(hits.length);
    for (int i = 0; i < hits.length; i++) {
        var h = hits[i];
        // Null out the slot before writing so the array no longer pins the hit
        // once it has been serialized and released.
        hits[i] = null;
        assert h != null;
        h.writeTo(out);
        // Each hit is released as soon as it is on the stream.
        // NOTE(review): if writeTo throws mid-loop, the remaining hits are never
        // decRef'd — confirm the failure path releases them elsewhere.
        h.decRef();
    }
    writeFooter(out);
}

// Writes the trailing wire-format fields (sort fields, collapse field, collapse values).
// The call order here defines the serialization format and must mirror the reader side.
private void writeFooter(StreamOutput out) throws IOException {
    out.writeOptional(Lucene::writeSortFieldArray, sortFields);
    out.writeOptionalString(collapseField);
    out.writeOptionalArray(Lucene::writeSortValue, collapseValues);
}

// Writes the leading wire-format fields: an optional TotalHits (flag byte plus the
// value when present) followed by the max score. Must mirror the reader side.
private void writeHeader(StreamOutput out) throws IOException {
    if (totalHits == null) {
        out.writeBoolean(false);
    } else {
        out.writeBoolean(true);
        Lucene.writeTotalHits(out, totalHits);
    }
    out.writeFloat(maxScore);
}

@Override
public boolean hasReferences() {
return refCounted.hasReferences();
Expand Down
39 changes: 32 additions & 7 deletions server/src/main/java/org/elasticsearch/search/SearchService.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import org.elasticsearch.common.CheckedSupplier;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput;
import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.common.settings.Setting;
Expand Down Expand Up @@ -95,7 +97,6 @@
import org.elasticsearch.search.dfs.DfsPhase;
import org.elasticsearch.search.dfs.DfsSearchResult;
import org.elasticsearch.search.fetch.FetchPhase;
import org.elasticsearch.search.fetch.FetchSearchResult;
import org.elasticsearch.search.fetch.QueryFetchSearchResult;
import org.elasticsearch.search.fetch.ScrollQueryFetchSearchResult;
import org.elasticsearch.search.fetch.ShardFetchRequest;
Expand Down Expand Up @@ -136,7 +137,9 @@
import org.elasticsearch.threadpool.Scheduler.Cancellable;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPool.Names;
import org.elasticsearch.transport.BytesTransportResponse;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.Transports;

import java.io.IOException;
Expand Down Expand Up @@ -1107,7 +1110,8 @@ private Executor getExecutor(IndexShard indexShard) {
public void executeFetchPhase(
InternalScrollSearchRequest request,
SearchShardTask task,
ActionListener<ScrollQueryFetchSearchResult> listener
RecyclerBytesStreamOutput networkBuffer,
ActionListener<TransportResponse> listener
) {
final LegacyReaderContext readerContext = (LegacyReaderContext) findReaderContext(request.contextId(), request);
final Releasable markAsUsed;
Expand Down Expand Up @@ -1139,8 +1143,14 @@ public void executeFetchPhase(
opsListener.onFailedQueryPhase(searchContext);
}
}
QueryFetchSearchResult fetchSearchResult = executeFetchPhase(readerContext, searchContext, afterQueryTime);
return new ScrollQueryFetchSearchResult(fetchSearchResult, searchContext.shardTarget());
var resp = executeFetchPhase(readerContext, searchContext, afterQueryTime);
if (networkBuffer == null) {
return new ScrollQueryFetchSearchResult(resp, searchContext.shardTarget());
}
searchContext.shardTarget().writeTo(networkBuffer);
resp.writeTo(networkBuffer);
resp.decRef();
return new BytesTransportResponse(new ReleasableBytesReference(networkBuffer.bytes(), networkBuffer));
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
logger.trace("Fetch phase failed", e);
Expand All @@ -1150,7 +1160,12 @@ public void executeFetchPhase(
}, wrapFailureListener(listener, readerContext, markAsUsed));
}

public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, ActionListener<FetchSearchResult> listener) {
public void executeFetchPhase(
ShardFetchRequest request,
CancellableTask task,
RecyclerBytesStreamOutput networkBuffer,
ActionListener<TransportResponse> listener
) {
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
Expand Down Expand Up @@ -1179,8 +1194,18 @@ public void executeFetchPhase(ShardFetchRequest request, CancellableTask task, A
}
var fetchResult = searchContext.fetchResult();
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
fetchResult.incRef();
return fetchResult;
if (networkBuffer == null) {
fetchResult.incRef();
return fetchResult;
}
try (networkBuffer) {
// no need to worry about releasing this instance safely before we write the first byte to it
// => the try-with-resources here is all we need to not leak any buffers
fetchResult.contextId.writeTo(networkBuffer);
fetchResult.consumeHits(networkBuffer);
networkBuffer.writeOptionalWriteable(fetchResult.profileResult());
return new BytesTransportResponse(networkBuffer.moveToBytesReference());
}
} catch (Exception e) {
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
// we handle the failure in the failure listener below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ public SearchHits hits() {
return hits;
}

/**
 * Serializes the hits to {@code out} and hands their ownership to the serialization
 * path: the hits reference is detached from this result before writing, so this
 * result no longer pins them on the heap afterwards.
 */
public void consumeHits(StreamOutput out) throws IOException {
    final var detached = this.hits;
    this.hits = null;
    detached.writeAndRelease(out);
}

public FetchSearchResult initCounter() {
counter = 0;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ public void sendResponse(Exception exception) {
}
}

@Override
public Compression.Scheme compressionScheme() {
    // Delegate to the wrapped channel so callers see the real channel's compression hint.
    return channel.compressionScheme();
}

@Override
public TransportVersion getVersion() {
return channel.getVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ public void sendResponse(Exception exception) {
}
}

@Override
public Compression.Scheme compressionScheme() {
    // Expose the scheme this channel was constructed with (null when uncompressed).
    return compressionScheme;
}

@Override
public TransportVersion getVersion() {
return version;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ public interface TransportChannel {

void sendResponse(Exception exception);

/**
* Returns a suggestion about the desired compression scheme to use for sending the response when using {@link BytesTransportResponse}
* to bypass transport layer serialization and compression.
*
* @return the suggested compression scheme to use for responses or {@code null} when not using compression
*/
Compression.Scheme compressionScheme();

/**
* Returns the version of the data to communicate in this channel.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,11 @@ public String toString() {
}
}

@Override
public Compression.Scheme compressionScheme() {
    // No compression hint here — presumably a direct/local channel whose responses
    // never cross the wire. TODO(review): confirm against the enclosing class.
    return null;
}

protected RemoteTransportException wrapInRemote(Exception e) {
return e instanceof RemoteTransportException remoteTransportException
? remoteTransportException
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.junit.Before;
Expand Down Expand Up @@ -412,8 +413,8 @@ public void testSearchWhileIndexDeleted() throws InterruptedException {
intCursors,
null/* not a scroll */
);
PlainActionFuture<FetchSearchResult> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), listener);
PlainActionFuture<TransportResponse> listener = new PlainActionFuture<>();
service.executeFetchPhase(req, new SearchShardTask(123L, "", "", "", null, emptyMap()), null, listener);
listener.get();
if (useScroll) {
// have to free context since this test does not remove the index from IndicesService.
Expand Down Expand Up @@ -601,9 +602,10 @@ public RankShardResult buildRankFeatureShardResult(SearchHits hits, int shardId)
// execute fetch phase and perform any validations once we retrieve the response
// the difference in how we do assertions here is needed because once the transport service sends back the response
// it decrements the reference to the FetchSearchResult (through the ActionListener#respondAndRelease) and sets hits to null
PlainActionFuture<FetchSearchResult> fetchListener = new PlainActionFuture<>() {
PlainActionFuture<TransportResponse> fetchListener = new PlainActionFuture<>() {
@Override
public void onResponse(FetchSearchResult fetchSearchResult) {
public void onResponse(TransportResponse response) {
FetchSearchResult fetchSearchResult = (FetchSearchResult) response;
assertNotNull(fetchSearchResult);
assertNotNull(fetchSearchResult.hits());

Expand All @@ -624,7 +626,7 @@ public void onFailure(Exception e) {
throw new AssertionError("No failure should have been raised", e);
}
};
service.executeFetchPhase(fetchRequest, searchTask, fetchListener);
service.executeFetchPhase(fetchRequest, searchTask, null, fetchListener);
fetchListener.get();
} catch (Exception ex) {
if (queryResult != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ public void sendResponse(Exception exception) {
in.sendResponse(exception);
}

@Override
public Compression.Scheme compressionScheme() {
    // Pure delegation to the wrapped channel.
    return in.compressionScheme();
}

@Override
public TransportVersion getVersion() {
return in.getVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,11 @@ public void sendResponse(Exception exception) {
channel.sendResponse(exception);

}

@Override
public Compression.Scheme compressionScheme() {
    // Pure delegation to the underlying channel.
    return channel.compressionScheme();
}
}, task);
} else {
return actualHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ public String toString() {
}
});
}

@Override
public Compression.Scheme compressionScheme() {
    // This channel reports no compression; responses sent through it will not be
    // pre-compressed. NOTE(review): appears to be a test-infrastructure channel.
    return null;
}
};

final TransportRequest copiedRequest;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,9 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
listener.onFailure(exception);
}

@Override
public Compression.Scheme compressionScheme() {
    // Listener-backed channel (responses go to an ActionListener, not the wire),
    // so there is no compression to hint at.
    return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.elasticsearch.threadpool.FixedExecutorBuilder;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.transport.AbstractSimpleTransportTestCase;
import org.elasticsearch.transport.Compression;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportRequest;
Expand Down Expand Up @@ -721,6 +722,11 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
in.sendResponse(exception);
}

@Override
public Compression.Scheme compressionScheme() {
    // Pure delegation to the wrapped channel.
    return in.compressionScheme();
}
}

private final List<CircuitBreaker> breakers = Collections.synchronizedList(new ArrayList<>());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.transport.Compression;
import org.elasticsearch.transport.NoSuchRemoteClusterException;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportResponse;
Expand Down Expand Up @@ -467,6 +468,11 @@ public void sendResponse(TransportResponse response) {
public void sendResponse(Exception exception) {
channel.sendResponse(exception);
}

@Override
public Compression.Scheme compressionScheme() {
    // Pure delegation to the wrapped channel.
    return channel.compressionScheme();
}
}, task)
);
}
Expand Down
Loading