diff --git a/Directory.Build.targets b/Directory.Build.targets index 1c5264bb2..920c2bf18 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -33,6 +33,7 @@ + diff --git a/src/JsonRpc/Connection.cs b/src/JsonRpc/Connection.cs index 31837e564..7bac4e3a6 100644 --- a/src/JsonRpc/Connection.cs +++ b/src/JsonRpc/Connection.cs @@ -7,7 +7,6 @@ namespace OmniSharp.Extensions.JsonRpc public class Connection : IDisposable { private readonly IInputHandler _inputHandler; - private readonly IRequestRouter _requestRouter; public Connection( Stream input, @@ -17,10 +16,9 @@ public Connection( IRequestRouter requestRouter, IResponseRouter responseRouter, ILoggerFactory loggerFactory, - ISerializer serializer) + ISerializer serializer, + int? concurrency) { - _requestRouter = requestRouter; - _inputHandler = new InputHandler( input, outputHandler, @@ -29,7 +27,8 @@ public Connection( requestRouter, responseRouter, loggerFactory, - serializer + serializer, + concurrency ); } diff --git a/src/JsonRpc/IScheduler.cs b/src/JsonRpc/IScheduler.cs index 2f1b5e447..e54d9ab15 100644 --- a/src/JsonRpc/IScheduler.cs +++ b/src/JsonRpc/IScheduler.cs @@ -1,4 +1,6 @@ using System; +using System.Reactive; +using System.Reactive.Linq; using System.Threading.Tasks; namespace OmniSharp.Extensions.JsonRpc @@ -6,6 +8,14 @@ namespace OmniSharp.Extensions.JsonRpc public interface IScheduler : IDisposable { void Start(); - void Add(RequestProcessType type, string name, Func request); + void Add(RequestProcessType type, string name, IObservable request); + } + + public static class SchedulerExtensions + { + public static void Add(this IScheduler scheduler, RequestProcessType type, string name, Func request) + { + scheduler.Add(type, name, Observable.FromAsync(request)); + } } } diff --git a/src/JsonRpc/InputHandler.cs b/src/JsonRpc/InputHandler.cs index fb67575b8..013439ede 100644 --- a/src/JsonRpc/InputHandler.cs +++ b/src/JsonRpc/InputHandler.cs @@ -37,7 +37,8 @@ public InputHandler( IRequestRouter requestRouter, IResponseRouter responseRouter, ILoggerFactory loggerFactory, - ISerializer serializer + ISerializer serializer, + int? concurrency ) { if (!input.CanRead) throw new ArgumentException($"must provide a readable stream for {nameof(input)}", nameof(input)); @@ -49,15 +50,15 @@ ISerializer serializer _responseRouter = responseRouter; _serializer = serializer; _logger = loggerFactory.CreateLogger(); - _scheduler = new ProcessScheduler(loggerFactory); + _scheduler = new ProcessScheduler(loggerFactory, concurrency); _inputThread = new Thread(ProcessInputStream) { IsBackground = true, Name = "ProcessInputStream" }; } public void Start() { + _scheduler.Start(); _outputHandler.Start(); _inputThread.Start(); - _scheduler.Start(); } // don't be async: We already allocated a seperate thread for this. diff --git a/src/JsonRpc/ProcessScheduler.cs b/src/JsonRpc/ProcessScheduler.cs index 7ad98c52d..7d6a6a1cc 100644 --- a/src/JsonRpc/ProcessScheduler.cs +++ b/src/JsonRpc/ProcessScheduler.cs @@ -1,6 +1,14 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.ComponentModel; +using System.Diagnostics; +using System.Linq; +using System.Reactive; +using System.Reactive.Concurrency; +using System.Reactive.Disposables; +using System.Reactive.Linq; +using System.Reactive.Subjects; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -9,122 +17,97 @@ namespace OmniSharp.Extensions.JsonRpc { public class ProcessScheduler : IScheduler { + private readonly int? _concurrency; private readonly ILogger _logger; - private readonly BlockingCollection<(RequestProcessType type, string name, Func request)> _queue; - private readonly CancellationTokenSource _cancel; - private readonly Thread _thread; + private readonly IObserver<(RequestProcessType type, string name, IObservable request)> _enqueue; + private readonly IObservable<(RequestProcessType type, string name, IObservable request)> _queue; + private bool _disposed = false; + private readonly CompositeDisposable _disposable = new CompositeDisposable(); + private readonly System.Reactive.Concurrency.IScheduler _scheduler; - public ProcessScheduler(ILoggerFactory loggerFactory) + public ProcessScheduler(ILoggerFactory loggerFactory, int? concurrency) : this(loggerFactory, concurrency, + new EventLoopScheduler( + _ => new Thread(_) {IsBackground = true, Name = "ProcessRequestQueue"})) { - _logger = loggerFactory.CreateLogger(); - _queue = new BlockingCollection<(RequestProcessType type, string name, Func request)>(); - _cancel = new CancellationTokenSource(); - _thread = new Thread(ProcessRequestQueue) { IsBackground = true, Name = "ProcessRequestQueue" }; } - public void Start() + internal ProcessScheduler(ILoggerFactory loggerFactory, int? concurrency, + System.Reactive.Concurrency.IScheduler scheduler) { - _thread.Start(); - } + _concurrency = concurrency; + _logger = loggerFactory.CreateLogger(); - public void Add(RequestProcessType type, string name, Func request) - { - _queue.Add((type, name, request)); + var subject = new Subject<(RequestProcessType type, string name, IObservable request)>(); + _disposable.Add(subject); + _enqueue = subject; + _scheduler = scheduler; + _queue = subject; } - private Task Start(Func request) + public void Start() { - var t = request(); - if (t.Status == TaskStatus.Created) // || t.Status = TaskStatus.WaitingForActivation ? - t.Start(); - return t; - } + var obs = Observable.Create(observer => { + var cd = new CompositeDisposable(); - private List RemoveCompleteTasks(List list) - { - if (list.Count == 0) return list; + var observableQueue = + new BehaviorSubject<(RequestProcessType type, ReplaySubject> observer)>(( + RequestProcessType.Serial, new ReplaySubject>(int.MaxValue))); + + cd.Add(_queue.Subscribe(item => { + if (observableQueue.Value.type != item.type) + { + observableQueue.Value.observer.OnCompleted(); + observableQueue.OnNext((item.type, new ReplaySubject>(int.MaxValue))); + } + + observableQueue.Value.observer.OnNext(HandleRequest(item.name, item.request)); + })); - var result = new List(); - foreach (var t in list) + cd.Add(observableQueue + .Select(item => { + var (type, replay) = item; + + if (type == RequestProcessType.Serial) + return replay.Concat(); + + return _concurrency.HasValue + ? replay.Merge(_concurrency.Value) + : replay.Merge(); + }) + .Concat() + .Subscribe(observer) + ); + + return cd; + }); + + _disposable.Add(obs + // .ObserveOn(_scheduler) + .Subscribe(_ => { }) + ); + + IObservable HandleRequest(string name, IObservable request) { - if (t.IsFaulted) - { - // TODO: Handle Fault - } - else if (!t.IsCompleted) - { - result.Add(t); - } + return request + .Catch(ex => Observable.Empty()) + .Catch(ex => { + _logger.LogCritical(Events.UnhandledException, ex, "Unhandled exception executing {Name}", + name); + return Observable.Empty(); + }); } - return result; } - public long _TestOnly_NonCompleteTaskCount = 0; - private void ProcessRequestQueue() + public void Add(RequestProcessType type, string name, IObservable request) { - // see https://github.com/OmniSharp/csharp-language-server-protocol/issues/4 - // no need to be async, because this thing already allocated a thread on it's own. - var token = _cancel.Token; - var waitables = new List(); - try - { - while (!token.IsCancellationRequested) - { - if (_queue.TryTake(out var item, Timeout.Infinite, token)) - { - var (type, name, request) = item; - try - { - if (type == RequestProcessType.Serial) - { - Task.WaitAll(waitables.ToArray(), token); - Start(request).Wait(token); - } - else if (type == RequestProcessType.Parallel) - { - waitables.Add(Start(request)); - } - else - throw new NotImplementedException("Only Serial and Parallel execution types can be handled currently"); - waitables = RemoveCompleteTasks(waitables); - Interlocked.Exchange(ref _TestOnly_NonCompleteTaskCount, waitables.Count); - } - catch (OperationCanceledException ex) when (ex.CancellationToken == token) - { - throw; - } - catch (Exception e) - { - // TODO: Should we rethrow or swallow? - // If an exception happens... the whole system could be in a bad state, hence this throwing currently. - _logger.LogCritical(Events.UnhandledException, e, "Unhandled exception executing {Name}", name); - throw; - } - } - } - } - catch (OperationCanceledException ex) when (ex.CancellationToken == token) - { - // OperationCanceledException - The CancellationToken has been canceled. - Task.WaitAll(waitables.ToArray(), TimeSpan.FromMilliseconds(1000)); - var keeponrunning = RemoveCompleteTasks(waitables); - Interlocked.Exchange(ref _TestOnly_NonCompleteTaskCount, keeponrunning.Count); - keeponrunning.ForEach((t) => - { - // TODO: There is no way to abort a Task. As we don't construct the tasks, we can do nothing here - // Option is: change the task factory "Func request" to a "Func request" - }); - } + _enqueue.OnNext((type, name, request)); } - private bool _disposed = false; public void Dispose() { if (_disposed) return; _disposed = true; - _cancel.Cancel(); - _thread.Join(); - _cancel.Dispose(); + _disposable.Dispose(); } } } diff --git a/src/JsonRpc/RequestRouterBase.cs b/src/JsonRpc/RequestRouterBase.cs index 144be2d40..4470b9f80 100644 --- a/src/JsonRpc/RequestRouterBase.cs +++ b/src/JsonRpc/RequestRouterBase.cs @@ -161,7 +161,7 @@ public virtual async Task RouteRequest(TDescriptor descriptor, Re return new JsonRpc.Client.Response(request.Id, responseValue, request); } - catch (TaskCanceledException) + catch (OperationCanceledException) { _logger.LogDebug("Request {Id} was cancelled", id); return new RequestCancelled(); diff --git a/src/Protocol/Document/Server/ICompletionHandler.cs b/src/Protocol/Document/Server/ICompletionHandler.cs index 08769ffd7..14f302483 100644 --- a/src/Protocol/Document/Server/ICompletionHandler.cs +++ b/src/Protocol/Document/Server/ICompletionHandler.cs @@ -12,7 +12,7 @@ namespace OmniSharp.Extensions.LanguageServer.Protocol.Server [Parallel, Method(DocumentNames.Completion)] public interface ICompletionHandler : IJsonRpcRequestHandler, IRegistration, ICapability { } - [Serial, Method(DocumentNames.CompletionResolve)] + [Parallel, Method(DocumentNames.CompletionResolve)] public interface ICompletionResolveHandler : ICanBeResolvedHandler { } public abstract class CompletionHandler : ICompletionHandler, ICompletionResolveHandler diff --git a/src/Server/LanguageServer.cs b/src/Server/LanguageServer.cs index e6ef666ef..9b08ec08a 100644 --- a/src/Server/LanguageServer.cs +++ b/src/Server/LanguageServer.cs @@ -54,6 +54,7 @@ public class LanguageServer : ILanguageServer, IInitializeHandler, IInitializedH private readonly SupportedCapabilities _supportedCapabilities; private Task _initializingTask; private readonly ILanguageServerConfiguration _configuration; + private readonly int? _concurrency; public static Task From(Action optionsAction) { @@ -118,7 +119,8 @@ public static ILanguageServer PreInit(LanguageServerOptions options) options.AddDefaultLoggingProvider, options.ProgressManager, options.ServerInfo, - options.ConfigurationBuilderAction + options.ConfigurationBuilderAction, + options.Concurrency ); } @@ -143,7 +145,8 @@ internal LanguageServer( bool addDefaultLoggingProvider, ProgressManager progressManager, ServerInfo serverInfo, - Action configurationBuilderAction) + Action configurationBuilderAction, + int? concurrency) { var outputHandler = new OutputHandler(output, serializer); @@ -234,7 +237,7 @@ internal LanguageServer( var requestRouter = _serviceProvider.GetRequiredService>(); _responseRouter = _serviceProvider.GetRequiredService(); - _connection = ActivatorUtilities.CreateInstance(_serviceProvider, input); + _connection = ActivatorUtilities.CreateInstance(_serviceProvider, input, concurrency); _exitHandler = new ServerExitHandler(_shutdownHandler); diff --git a/src/Server/LanguageServerOptions.cs b/src/Server/LanguageServerOptions.cs index 0e15eb4a2..8c977aed2 100644 --- a/src/Server/LanguageServerOptions.cs +++ b/src/Server/LanguageServerOptions.cs @@ -38,6 +38,7 @@ public LanguageServerOptions() internal Action LoggingBuilderAction { get; set; } = new Action(_ => { }); internal Action ConfigurationBuilderAction { get; set; } = new Action(_ => { }); internal bool AddDefaultLoggingProvider { get; set; } + public int? Concurrency { get; set; } internal readonly List InitializeDelegates = new List(); internal readonly List InitializedDelegates = new List(); diff --git a/src/Server/LanguageServerOptionsExtensions.cs b/src/Server/LanguageServerOptionsExtensions.cs index 417138d2a..674ec3b74 100644 --- a/src/Server/LanguageServerOptionsExtensions.cs +++ b/src/Server/LanguageServerOptionsExtensions.cs @@ -86,6 +86,18 @@ public static LanguageServerOptions WithServerInfo(this LanguageServerOptions op return options; } + /// + /// Set maximum number of allowed parallel actions + /// + /// + /// + /// + public static LanguageServerOptions WithConcurrency(this LanguageServerOptions options, int? concurrency) + { + options.Concurrency = concurrency; + return options; + } + public static LanguageServerOptions OnInitialize(this LanguageServerOptions options, InitializeDelegate @delegate) { options.InitializeDelegates.Add(@delegate); diff --git a/test/Directory.Build.targets b/test/Directory.Build.targets index 5bcafa0b6..eff0c929b 100644 --- a/test/Directory.Build.targets +++ b/test/Directory.Build.targets @@ -14,6 +14,7 @@ + diff --git a/test/JsonRpc.Tests/DapInputHandlerTests.cs b/test/JsonRpc.Tests/DapInputHandlerTests.cs index 0636bbe9a..710658822 100644 --- a/test/JsonRpc.Tests/DapInputHandlerTests.cs +++ b/test/JsonRpc.Tests/DapInputHandlerTests.cs @@ -42,7 +42,8 @@ private static InputHandler NewHandler( requestRouter, responseRouter, Substitute.For(), - new DapSerializer()); + new DapSerializer(), + null); handler.Start(); cts.Wait(); Task.Delay(10).Wait(); diff --git a/test/JsonRpc.Tests/InputHandlerTests.cs b/test/JsonRpc.Tests/InputHandlerTests.cs index 87adb7ed5..7e6697903 100644 --- a/test/JsonRpc.Tests/InputHandlerTests.cs +++ b/test/JsonRpc.Tests/InputHandlerTests.cs @@ -44,7 +44,8 @@ private static InputHandler NewHandler( requestRouter, responseRouter, Substitute.For(), - new JsonRpcSerializer()); + new JsonRpcSerializer(), + null); handler.Start(); cts.Wait(); Task.Delay(10).Wait(); @@ -265,7 +266,7 @@ public void ShouldHandleResponse() } [Fact] - public async Task ShouldCancelRequest() + public void ShouldCancelRequest() { var inputStream = new MemoryStream(Encoding.ASCII.GetBytes("Content-Length: 2\r\n\r\n{}")); var outputHandler = Substitute.For(); diff --git a/test/JsonRpc.Tests/ProcessSchedulerTests.cs b/test/JsonRpc.Tests/ProcessSchedulerTests.cs index 6bec32ad2..e1963970a 100644 --- a/test/JsonRpc.Tests/ProcessSchedulerTests.cs +++ b/test/JsonRpc.Tests/ProcessSchedulerTests.cs @@ -5,8 +5,13 @@ using Xunit; using FluentAssertions; using System.Collections.Generic; +using System.Reactive; +using System.Reactive.Linq; +using Microsoft.Reactive.Testing; using OmniSharp.Extensions.JsonRpc; using Xunit.Abstractions; +using Xunit.Sdk; +using static Microsoft.Reactive.Testing.ReactiveTest; namespace JsonRpc.Tests { @@ -17,163 +22,281 @@ public ProcessSchedulerTests(ITestOutputHelper testOutputHelper) _testOutputHelper = testOutputHelper; } - private const int SLEEPTIME_MS = 20; - private const int ALONGTIME_MS = 500; private readonly ITestOutputHelper _testOutputHelper; class AllRequestProcessTypes : TheoryData { public override IEnumerator GetEnumerator() { - yield return new object[] { RequestProcessType.Serial }; - yield return new object[] { RequestProcessType.Parallel }; + yield return new object[] {RequestProcessType.Serial}; + yield return new object[] {RequestProcessType.Parallel}; } } [Theory, ClassData(typeof(AllRequestProcessTypes))] public void ShouldScheduleCompletedTask(RequestProcessType type) { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(1); - s.Start(); - s.Add(type, "bogus", () => - { - done.Signal(); - return Task.CompletedTask; - }); - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + s.Add(type, "bogus", testObservable.Do(testObserver)); + + testScheduler.AdvanceTo(50); + + testObservable.Subscriptions.Count.Should().Be(1); + + testScheduler.AdvanceTo(101); + + testObservable.Subscriptions.Count.Should().Be(1); + testObserver.Messages.Should().Contain(z => z.Value.Kind == NotificationKind.OnNext); + testObserver.Messages.Should().Contain(z => z.Value.Kind == NotificationKind.OnCompleted); } [Fact] - public void ShouldScheduleAwaitableTask() + public void ShouldScheduleSerialInOrder() { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(1); - s.Start(); - s.Add(RequestProcessType.Serial, "bogus", async () => - { - await Task.Yield(); - done.Signal(); - }); - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + for (var i = 0; i < 8; i++) + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(8); + testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted).Should() + .ContainInOrder( + OnNext(100, Unit.Default), + OnNext(200, Unit.Default), + OnNext(300, Unit.Default), + OnNext(400, Unit.Default), + OnNext(500, Unit.Default), + OnNext(600, Unit.Default), + OnNext(700, Unit.Default), + OnNext(800, Unit.Default) + ); } [Fact] - public void ShouldScheduleConstructedTask() + public void ShouldScheduleParallelInParallel() { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(1); - s.Start(); - s.Add(RequestProcessType.Serial, "bogus", () => - { - return new Task(() => - { - done.Signal(); - }); - }); - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + for (var i = 0; i < 8; i++) + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(8); + testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted).Should() + .ContainInOrder( + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default) + ); } [Fact] - public void ShouldScheduleSerialInOrder() + public void ShouldScheduleMixed() { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(3); // 3x s.Add - var running = 0; - var peek = 0; - - Func HandlePeek = async () => - { - var p = Interlocked.Increment(ref running); - lock (this) peek = Math.Max(peek, p); - await Task.Delay(SLEEPTIME_MS); // give a different HandlePeek task a chance to run - Interlocked.Decrement(ref running); - done.Signal(); - }; - - s.Start(); - for (var i = 0; i < done.CurrentCount; i++) - s.Add(RequestProcessType.Serial, "bogus", HandlePeek); - - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - running.Should().Be(0, because: "all tasks have to run normally"); - peek.Should().Be(1, because: "all tasks must not overlap"); - s.Dispose(); - Interlocked.Read(ref ((ProcessScheduler)s)._TestOnly_NonCompleteTaskCount).Should().Be(0, because: "the scheduler must not wait for tasks to complete after disposal"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(8); + testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted).Should() + .ContainInOrder( + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(200, Unit.Default), + OnNext(300, Unit.Default), + OnNext(300, Unit.Default), + OnNext(400, Unit.Default) + ); } [Fact] - public void ShouldScheduleParallelInParallel() + public void ShouldScheduleSerial() { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(8); // 8x s.Add - var running = 0; - var peek = 0; - - Func HandlePeek = async () => - { - var p = Interlocked.Increment(ref running); - lock (this) peek = Math.Max(peek, p); - await Task.Delay(SLEEPTIME_MS); // give a different HandlePeek task a chance to run - Interlocked.Decrement(ref running); - done.Signal(); - }; - - s.Start(); - for (var i = 0; i < done.CurrentCount; i++) - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - running.Should().Be(0, because: "all tasks have to run normally"); - peek.Should().BeGreaterThan(3, because: "a lot of tasks should overlap"); - s.Dispose(); - Interlocked.Read(ref ((ProcessScheduler)s)._TestOnly_NonCompleteTaskCount).Should().Be(0, because: "the scheduler must not wait for tasks to complete after disposal"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(4); + testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted).Should() + .ContainInOrder( + OnNext(100, Unit.Default), + OnNext(200, Unit.Default), + OnNext(300, Unit.Default), + OnNext(400, Unit.Default) + ); } [Fact] - public void ShouldScheduleMixed() + public void ShouldScheduleWithConcurrency() { - using (IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper))) - { - var done = new CountdownEvent(8); // 8x s.Add - var running = 0; - var peek = 0; - - Func HandlePeek = async () => - { - var p = Interlocked.Increment(ref running); - lock (this) peek = Math.Max(peek, p); - await Task.Delay(SLEEPTIME_MS); // give a different HandlePeek task a chance to run - Interlocked.Decrement(ref running); - done.Signal(); - }; - - s.Start(); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Serial, "bogus", HandlePeek); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Parallel, "bogus", HandlePeek); - s.Add(RequestProcessType.Serial, "bogus", HandlePeek); - - done.Wait(ALONGTIME_MS).Should().Be(true, because: "all tasks have to run"); - running.Should().Be(0, because: "all tasks have to run normally"); - peek.Should().BeGreaterThan(2, because: "some tasks should overlap"); - s.Dispose(); - Interlocked.Read(ref ((ProcessScheduler)s)._TestOnly_NonCompleteTaskCount).Should().Be(0, because: "the scheduler must not wait for tasks to complete after disposal"); - } + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), 3); + + s.Start(); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Parallel, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(8); + testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted).Should() + .ContainInOrder( + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(100, Unit.Default), + OnNext(200, Unit.Default), + OnNext(300, Unit.Default), + OnNext(400, Unit.Default), + OnNext(400, Unit.Default), + OnNext(500, Unit.Default) + ); + } + + [Fact] + public void Should_Handle_Cancelled_Tasks() + { + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var errorObservable = testScheduler.CreateColdObservable( + OnError(100, new TaskCanceledException(), Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "somethingelse", errorObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(2); + var messages = testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted) + .ToArray(); + + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnNext && x.Time == 100); + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnError && x.Time == 200 && x.Value.Exception is OperationCanceledException); + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnNext && x.Time == 300); + } + + [Fact] + public void Should_Handle_Exceptions_Tasks() + { + var testScheduler = new TestScheduler(); + var testObservable = testScheduler.CreateColdObservable( + OnNext(100, Unit.Default), + OnCompleted(100, Unit.Default) + ); + var errorObservable = testScheduler.CreateColdObservable( + OnError(100, new NotSameException(), Unit.Default) + ); + var testObserver = testScheduler.CreateObserver(); + + using IScheduler s = new ProcessScheduler(new TestLoggerFactory(_testOutputHelper), null); + + s.Start(); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "somethingelse", errorObservable.Do(testObserver)); + s.Add(RequestProcessType.Serial, "bogus", testObservable.Do(testObserver)); + + testScheduler.Start(); + + testObservable.Subscriptions.Count.Should().Be(2); + var messages = testObserver.Messages + .Where(z => z.Value.Kind != NotificationKind.OnCompleted) + .ToArray(); + + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnNext && x.Time == 100); + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnError && x.Time == 200 && x.Value.Exception is NotSameException); + messages.Should().Contain(x => x.Value.Kind == NotificationKind.OnNext && x.Time == 300); + } } }