bug fix - remove port bytes during port forwarding (#233)
* fix a race condition. when multiple call to GetStream happens around the same time, on the same inputIndex, a race condition will cause this.buffers.Add() to throw exception. * add WebSocket server certificate validation support for net 452 * port forwarding bug fix In StreamDemuxer, if the buffer is created before connection is established, the port bytes are not removed when the bytes are delivering to the client. Repro code looks like (the key is to call remoteStreams.Start AFTER GetStream): var ws = await kubernetesClient.WebSocketNamespacedPodPortForwardAsync (... ) var remoteStreams = new StreamDemuxer(ws); var stream = remoteStreams.GetStream(0, 0); remoteStreams.Start(); This change filters out the port bytes which are the 2nd and 3rd bytes sent in all cases. * incorporate review feedbacks. add an enum StreamType to StreamDemuxer so it knows whether / how the data stream should be handled. * add tests, fix skip bytes scenario for multiple streams * add more tests for verifying content. * simplify code a bit
This commit is contained in:
@@ -26,6 +26,7 @@ namespace k8s
|
|||||||
private readonly WebSocket webSocket;
|
private readonly WebSocket webSocket;
|
||||||
private readonly Dictionary<byte, ByteBuffer> buffers = new Dictionary<byte, ByteBuffer>();
|
private readonly Dictionary<byte, ByteBuffer> buffers = new Dictionary<byte, ByteBuffer>();
|
||||||
private readonly CancellationTokenSource cts = new CancellationTokenSource();
|
private readonly CancellationTokenSource cts = new CancellationTokenSource();
|
||||||
|
private readonly StreamType streamType;
|
||||||
private Task runLoop;
|
private Task runLoop;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
@@ -34,8 +35,12 @@ namespace k8s
|
|||||||
/// <param name="webSocket">
|
/// <param name="webSocket">
|
||||||
/// A <see cref="WebSocket"/> which contains a multiplexed stream, such as the <see cref="WebSocket"/> returned by the exec or attach commands.
|
/// A <see cref="WebSocket"/> which contains a multiplexed stream, such as the <see cref="WebSocket"/> returned by the exec or attach commands.
|
||||||
/// </param>
|
/// </param>
|
||||||
public StreamDemuxer(WebSocket webSocket)
|
/// <param name="streamType">
|
||||||
|
/// A <see cref="StreamType"/> specifies the type of the stream.
|
||||||
|
/// </param>
|
||||||
|
public StreamDemuxer(WebSocket webSocket, StreamType streamType = StreamType.RemoteCommand)
|
||||||
{
|
{
|
||||||
|
this.streamType = streamType;
|
||||||
this.webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
|
this.webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -179,7 +184,8 @@ namespace k8s
|
|||||||
{
|
{
|
||||||
// Get a 1KB buffer
|
// Get a 1KB buffer
|
||||||
byte[] buffer = ArrayPool<byte>.Shared.Rent(1024 * 1024);
|
byte[] buffer = ArrayPool<byte>.Shared.Rent(1024 * 1024);
|
||||||
|
// This maps remembers bytes skipped for each stream.
|
||||||
|
Dictionary<byte, int> streamBytesToSkipMap = new Dictionary<byte, int>();
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var segment = new ArraySegment<byte>(buffer);
|
var segment = new ArraySegment<byte>(buffer);
|
||||||
@@ -202,11 +208,35 @@ namespace k8s
|
|||||||
|
|
||||||
while (true)
|
while (true)
|
||||||
{
|
{
|
||||||
if (this.buffers.ContainsKey(streamIndex))
|
int bytesToSkip = 0;
|
||||||
|
if (!streamBytesToSkipMap.TryGetValue(streamIndex, out bytesToSkip))
|
||||||
{
|
{
|
||||||
this.buffers[streamIndex].Write(buffer, extraByteCount, result.Count - extraByteCount);
|
// When used in port-forwarding, the first 2 bytes from the web socket is port bytes, skip.
|
||||||
|
// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/server/portforward/websocket.go
|
||||||
|
bytesToSkip = this.streamType == StreamType.PortForward ? 2 : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int bytesCount = result.Count - extraByteCount;
|
||||||
|
if (bytesToSkip > 0 && bytesToSkip >= bytesCount)
|
||||||
|
{
|
||||||
|
// skip the entire data.
|
||||||
|
bytesToSkip -= bytesCount;
|
||||||
|
extraByteCount += bytesCount;
|
||||||
|
bytesCount = 0;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
bytesCount -= bytesToSkip;
|
||||||
|
extraByteCount += bytesToSkip;
|
||||||
|
bytesToSkip = 0;
|
||||||
|
|
||||||
|
if (this.buffers.ContainsKey(streamIndex))
|
||||||
|
{
|
||||||
|
this.buffers[streamIndex].Write(buffer, extraByteCount, bytesCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
streamBytesToSkipMap[streamIndex] = bytesToSkip;
|
||||||
|
|
||||||
if (result.EndOfMessage == true)
|
if (result.EndOfMessage == true)
|
||||||
{
|
{
|
||||||
break;
|
break;
|
||||||
|
|||||||
22
src/KubernetesClient/StreamType.cs
Normal file
22
src/KubernetesClient/StreamType.cs
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
using System;
|
||||||
|
|
||||||
|
namespace k8s
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// When creating a <see cref="StreamDemuxer"/> object, specify <see cref="StreamType"/> to properly handle
|
||||||
|
/// the underlying communication.
|
||||||
|
/// </summary>
|
||||||
|
public enum StreamType
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// This <see cref="StreamDemuxer"/> object is used to stream a remote command or attach to a remote
|
||||||
|
/// container.
|
||||||
|
/// </summary>
|
||||||
|
RemoteCommand,
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// This <see cref="StreamDemuxer"/> object is used in port forwarding.
|
||||||
|
/// </summary>
|
||||||
|
PortForward
|
||||||
|
}
|
||||||
|
}
|
||||||
140
tests/KubernetesClient.Tests/Mock/MockWebSocket.cs
Normal file
140
tests/KubernetesClient.Tests/Mock/MockWebSocket.cs
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Concurrent;
|
||||||
|
using System.Net.WebSockets;
|
||||||
|
using System.Threading;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
|
namespace k8s.tests.Mock
|
||||||
|
{
|
||||||
|
public class MockWebSocket : WebSocket
|
||||||
|
{
|
||||||
|
private WebSocketCloseStatus? closeStatus = null;
|
||||||
|
private string closeStatusDescription;
|
||||||
|
private WebSocketState state;
|
||||||
|
private string subProtocol;
|
||||||
|
private ConcurrentQueue<MessageData> receiveBuffers = new ConcurrentQueue<MessageData>();
|
||||||
|
private AutoResetEvent receiveEvent = new AutoResetEvent(false);
|
||||||
|
|
||||||
|
public MockWebSocket(string subProtocol = null)
|
||||||
|
{
|
||||||
|
this.subProtocol = subProtocol;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void SetState(WebSocketState state)
|
||||||
|
{
|
||||||
|
this.state = state;
|
||||||
|
}
|
||||||
|
|
||||||
|
public EventHandler<MessageDataEventArgs> MessageSent;
|
||||||
|
|
||||||
|
public Task InvokeReceiveAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage)
|
||||||
|
{
|
||||||
|
this.receiveBuffers.Enqueue(new MessageData()
|
||||||
|
{
|
||||||
|
Buffer = buffer,
|
||||||
|
MessageType = messageType,
|
||||||
|
EndOfMessage = endOfMessage
|
||||||
|
});
|
||||||
|
this.receiveEvent.Set();
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
|
||||||
|
#region WebSocket overrides
|
||||||
|
public override WebSocketCloseStatus? CloseStatus => this.closeStatus;
|
||||||
|
|
||||||
|
public override string CloseStatusDescription => this.closeStatusDescription;
|
||||||
|
|
||||||
|
public override WebSocketState State => this.state;
|
||||||
|
|
||||||
|
public override string SubProtocol => this.subProtocol;
|
||||||
|
|
||||||
|
public override void Abort()
|
||||||
|
{
|
||||||
|
throw new NotImplementedException();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
|
||||||
|
{
|
||||||
|
this.closeStatus = closeStatus;
|
||||||
|
this.closeStatusDescription = statusDescription;
|
||||||
|
this.receiveBuffers.Enqueue(new MessageData()
|
||||||
|
{
|
||||||
|
Buffer = new ArraySegment<byte>(new byte[] { }),
|
||||||
|
EndOfMessage = true,
|
||||||
|
MessageType = WebSocketMessageType.Close
|
||||||
|
});
|
||||||
|
this.receiveEvent.Set();
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
|
||||||
|
public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
|
||||||
|
{
|
||||||
|
throw new NotImplementedException();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override void Dispose()
|
||||||
|
{
|
||||||
|
this.receiveBuffers.Clear();
|
||||||
|
this.receiveEvent.Set();
|
||||||
|
}
|
||||||
|
|
||||||
|
public override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
|
||||||
|
{
|
||||||
|
if (this.receiveBuffers.Count == 0)
|
||||||
|
{
|
||||||
|
this.receiveEvent.WaitOne();
|
||||||
|
}
|
||||||
|
int bytesReceived = 0;
|
||||||
|
bool endOfMessage = true;
|
||||||
|
WebSocketMessageType messageType = WebSocketMessageType.Close;
|
||||||
|
|
||||||
|
MessageData received = null;
|
||||||
|
if (this.receiveBuffers.TryPeek(out received))
|
||||||
|
{
|
||||||
|
messageType = received.MessageType;
|
||||||
|
if (received.Buffer.Count <= buffer.Count)
|
||||||
|
{
|
||||||
|
this.receiveBuffers.TryDequeue(out received);
|
||||||
|
received.Buffer.CopyTo(buffer);
|
||||||
|
bytesReceived = received.Buffer.Count;
|
||||||
|
endOfMessage = received.EndOfMessage;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
received.Buffer.Slice(0, buffer.Count).CopyTo(buffer);
|
||||||
|
bytesReceived = buffer.Count;
|
||||||
|
endOfMessage = false;
|
||||||
|
received.Buffer = received.Buffer.Slice(buffer.Count);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Task.FromResult(new WebSocketReceiveResult(bytesReceived, messageType, endOfMessage));
|
||||||
|
}
|
||||||
|
|
||||||
|
public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
|
||||||
|
{
|
||||||
|
this.MessageSent?.Invoke(this, new MessageDataEventArgs()
|
||||||
|
{
|
||||||
|
Data = new MessageData()
|
||||||
|
{
|
||||||
|
Buffer = buffer,
|
||||||
|
MessageType = messageType,
|
||||||
|
EndOfMessage = endOfMessage
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
#endregion
|
||||||
|
|
||||||
|
public class MessageData
|
||||||
|
{
|
||||||
|
public ArraySegment<byte> Buffer { get; set; }
|
||||||
|
public WebSocketMessageType MessageType { get; set; }
|
||||||
|
public bool EndOfMessage { get; set; }
|
||||||
|
}
|
||||||
|
|
||||||
|
public class MessageDataEventArgs : EventArgs
|
||||||
|
{
|
||||||
|
public MessageData Data { get; set; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
432
tests/KubernetesClient.Tests/StreamDemuxerTests.cs
Normal file
432
tests/KubernetesClient.Tests/StreamDemuxerTests.cs
Normal file
@@ -0,0 +1,432 @@
|
|||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Net.WebSockets;
|
||||||
|
using System.Text;
|
||||||
|
using System.Threading;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using k8s.tests.Mock;
|
||||||
|
using Xunit;
|
||||||
|
using Xunit.Abstractions;
|
||||||
|
|
||||||
|
namespace k8s.Tests
|
||||||
|
{
|
||||||
|
public class StreamDemuxerTests
|
||||||
|
{
|
||||||
|
private readonly ITestOutputHelper testOutput;
|
||||||
|
|
||||||
|
public StreamDemuxerTests(ITestOutputHelper testOutput)
|
||||||
|
{
|
||||||
|
this.testOutput = testOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task SendDataRemoteCommand()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
List<byte> sentBuffer = new List<byte>();
|
||||||
|
ws.MessageSent += (sender, args) =>
|
||||||
|
{
|
||||||
|
sentBuffer.AddRange(args.Data.Buffer);
|
||||||
|
};
|
||||||
|
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
byte channelIndex = 12;
|
||||||
|
var stream = demuxer.GetStream(channelIndex, channelIndex);
|
||||||
|
var b = GenerateRandomBuffer(100, 0xEF);
|
||||||
|
stream.Write(b, 0, b.Length);
|
||||||
|
|
||||||
|
// Send 100 bytes, expect 1 (channel index) + 100 (payload) = 101 bytes
|
||||||
|
Assert.True(await WaitForAsync(() => sentBuffer.Count == 101), $"Demuxer error: expect to send 101 bytes, but actually send {sentBuffer.Count} bytes.");
|
||||||
|
Assert.True(sentBuffer[0] == channelIndex, "The first sent byte is not channel index!");
|
||||||
|
Assert.True(sentBuffer[1] == 0xEF, "Incorrect payload!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task SendMultipleDataRemoteCommand()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
List<byte> sentBuffer = new List<byte>();
|
||||||
|
ws.MessageSent += (sender, args) =>
|
||||||
|
{
|
||||||
|
sentBuffer.AddRange(args.Data.Buffer);
|
||||||
|
};
|
||||||
|
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
byte channelIndex = 12;
|
||||||
|
var stream = demuxer.GetStream(channelIndex, channelIndex);
|
||||||
|
var b = GenerateRandomBuffer(100, 0xEF);
|
||||||
|
stream.Write(b, 0, b.Length);
|
||||||
|
b = GenerateRandomBuffer(200, 0xAB);
|
||||||
|
stream.Write(b, 0, b.Length);
|
||||||
|
|
||||||
|
// Send 300 bytes in 2 messages, expect 1 (channel index) * 2 + 300 (payload) = 302 bytes
|
||||||
|
Assert.True(await WaitForAsync(() => sentBuffer.Count == 302), $"Demuxer error: expect to send 302 bytes, but actually send {sentBuffer.Count} bytes.");
|
||||||
|
Assert.True(sentBuffer[0] == channelIndex, "The first sent byte is not channel index!");
|
||||||
|
Assert.True(sentBuffer[1] == 0xEF, "The first part of payload incorrect!");
|
||||||
|
Assert.True(sentBuffer[101] == channelIndex, "The second message first byte is not channel index!");
|
||||||
|
Assert.True(sentBuffer[102] == 0xAB, "The second part of payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ReceiveDataRemoteCommand()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
List<byte> receivedBuffer = new List<byte>();
|
||||||
|
byte channelIndex = 12;
|
||||||
|
var stream = demuxer.GetStream(channelIndex, channelIndex);
|
||||||
|
|
||||||
|
// Receive 600 bytes in 3 messages. Exclude 1 channel index byte per message, expect 597 bytes payload.
|
||||||
|
int expectedCount = 597;
|
||||||
|
|
||||||
|
var t = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(100, channelIndex, 0xAA, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(200, channelIndex, 0xAB, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(300, channelIndex, 0xAC, false)), WebSocketMessageType.Binary, true);
|
||||||
|
|
||||||
|
await WaitForAsync(() => receivedBuffer.Count == expectedCount);
|
||||||
|
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "normal", CancellationToken.None);
|
||||||
|
});
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await t;
|
||||||
|
|
||||||
|
Assert.True(receivedBuffer.Count == expectedCount, $"Demuxer error: expect to receive {expectedCount} bytes, but actually got {receivedBuffer.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer[0] == 0xAA, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[98] == 0xAA, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[99] == 0xAB, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[297] == 0xAB, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[298] == 0xAC, "The third payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[596] == 0xAC, "The third payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ReceiveDataPortForward()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws, StreamType.PortForward);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
List<byte> receivedBuffer = new List<byte>();
|
||||||
|
byte channelIndex = 12;
|
||||||
|
var stream = demuxer.GetStream(channelIndex, channelIndex);
|
||||||
|
|
||||||
|
// Receive 600 bytes in 3 messages. Exclude 1 channel index byte per message, and 2 port bytes in the first message.
|
||||||
|
// expect 600 - 3 - 2 = 595 bytes payload.
|
||||||
|
int expectedCount = 595;
|
||||||
|
|
||||||
|
var t = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(100, channelIndex, 0xB1, true)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(200, channelIndex, 0xB2, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(300, channelIndex, 0xB3, false)), WebSocketMessageType.Binary, true);
|
||||||
|
|
||||||
|
await WaitForAsync(() => receivedBuffer.Count == expectedCount);
|
||||||
|
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "normal", CancellationToken.None);
|
||||||
|
});
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await t;
|
||||||
|
|
||||||
|
Assert.True(receivedBuffer.Count == expectedCount, $"Demuxer error: expect to receive {expectedCount} bytes, but actually got {receivedBuffer.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer[0] == 0xB1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[96] == 0xB1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[97] == 0xB2, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[295] == 0xB2, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[296] == 0xB3, "The third payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[594] == 0xB3, "The third payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ReceiveDataPortForwardOneByteMessage()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws, StreamType.PortForward);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
List<byte> receivedBuffer = new List<byte>();
|
||||||
|
byte channelIndex = 12;
|
||||||
|
var stream = demuxer.GetStream(channelIndex, channelIndex);
|
||||||
|
|
||||||
|
// Receive 402 bytes in 3 buffers of 2 messages. Exclude 1 channel index byte per message, and 2 port bytes in the first message.
|
||||||
|
// expect 402 - 1 x 2 - 2 = 398 bytes payload.
|
||||||
|
int expectedCount = 398;
|
||||||
|
|
||||||
|
var t = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(2, channelIndex, 0xC1, true)), WebSocketMessageType.Binary, false);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(100, channelIndex, 0xC2, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(300, channelIndex, 0xC3, false)), WebSocketMessageType.Binary, true);
|
||||||
|
|
||||||
|
await WaitForAsync(() => receivedBuffer.Count == expectedCount);
|
||||||
|
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "normal", CancellationToken.None);
|
||||||
|
});
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
await t;
|
||||||
|
|
||||||
|
Assert.True(receivedBuffer.Count == expectedCount, $"Demuxer error: expect to receive {expectedCount} bytes, but actually got {receivedBuffer.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer[0] == 0xC2, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[98] == 0xC2, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[99] == 0xC3, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer[397] == 0xC3, "The second payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ReceiveDataRemoteCommandMultipleStream()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
List<byte> receivedBuffer1 = new List<byte>();
|
||||||
|
byte channelIndex1 = 1;
|
||||||
|
var stream1 = demuxer.GetStream(channelIndex1, channelIndex1);
|
||||||
|
List<byte> receivedBuffer2 = new List<byte>();
|
||||||
|
byte channelIndex2 = 2;
|
||||||
|
var stream2 = demuxer.GetStream(channelIndex2, channelIndex2);
|
||||||
|
|
||||||
|
// stream 1: receive 100 + 300 = 400 bytes, exclude 1 channel index per message, expect 400 - 1 x 2 = 398 bytes.
|
||||||
|
int expectedCount1 = 398;
|
||||||
|
|
||||||
|
// stream 2: receive 200 bytes, exclude 1 channel index per message, expect 200 - 1 = 199 bytes.
|
||||||
|
int expectedCount2 = 199;
|
||||||
|
|
||||||
|
var t1 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
// Simulate WebSocket received remote data to multiple streams
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(100, channelIndex1, 0xD1, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(200, channelIndex2, 0xD2, false)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(300, channelIndex1, 0xD3, false)), WebSocketMessageType.Binary, true);
|
||||||
|
|
||||||
|
await WaitForAsync(() => receivedBuffer1.Count == expectedCount1);
|
||||||
|
await WaitForAsync(() => receivedBuffer2.Count == expectedCount2);
|
||||||
|
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "normal", CancellationToken.None);
|
||||||
|
});
|
||||||
|
var t2 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream1.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer1.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
var t3 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream2.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer2.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
await Task.WhenAll(t1, t2, t3);
|
||||||
|
|
||||||
|
Assert.True(receivedBuffer1.Count == expectedCount1, $"Demuxer error: expect to receive {expectedCount1} bytes, but actually got {receivedBuffer1.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer2.Count == expectedCount2, $"Demuxer error: expect to receive {expectedCount2} bytes, but actually got {receivedBuffer2.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer1[0] == 0xD1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[98] == 0xD1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[99] == 0xD3, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[397] == 0xD3, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer2[0] == 0xD2, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer2[198] == 0xD2, "The first payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ReceiveDataPortForwardMultipleStream()
|
||||||
|
{
|
||||||
|
using (MockWebSocket ws = new MockWebSocket())
|
||||||
|
{
|
||||||
|
StreamDemuxer demuxer = new StreamDemuxer(ws, StreamType.PortForward);
|
||||||
|
Task.Run(() => demuxer.Start());
|
||||||
|
|
||||||
|
List<byte> receivedBuffer1 = new List<byte>();
|
||||||
|
byte channelIndex1 = 1;
|
||||||
|
var stream1 = demuxer.GetStream(channelIndex1, channelIndex1);
|
||||||
|
List<byte> receivedBuffer2 = new List<byte>();
|
||||||
|
byte channelIndex2 = 2;
|
||||||
|
var stream2 = demuxer.GetStream(channelIndex2, channelIndex2);
|
||||||
|
|
||||||
|
// stream 1: receive 100 + 300 = 400 bytes, exclude 1 channel index per message, exclude port bytes in the first message,
|
||||||
|
// expect 400 - 1 x 2 - 2 = 396 bytes.
|
||||||
|
int expectedCount1 = 396;
|
||||||
|
|
||||||
|
// stream 2: receive 200 bytes, exclude 1 channel index per message, exclude port bytes in the first message,
|
||||||
|
// expect 200 - 1 - 2 = 197 bytes.
|
||||||
|
int expectedCount2 = 197;
|
||||||
|
|
||||||
|
var t1 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
// Simulate WebSocket received remote data to multiple streams
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(100, channelIndex1, 0xE1, true)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(200, channelIndex2, 0xE2, true)), WebSocketMessageType.Binary, true);
|
||||||
|
await ws.InvokeReceiveAsync(new ArraySegment<byte>(GenerateRandomBuffer(300, channelIndex1, 0xE3, false)), WebSocketMessageType.Binary, true);
|
||||||
|
|
||||||
|
await WaitForAsync(() => receivedBuffer1.Count == expectedCount1);
|
||||||
|
await WaitForAsync(() => receivedBuffer2.Count == expectedCount2);
|
||||||
|
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "normal", CancellationToken.None);
|
||||||
|
});
|
||||||
|
var t2 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream1.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer1.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
var t3 = Task.Run(async () =>
|
||||||
|
{
|
||||||
|
var buffer = new byte[50];
|
||||||
|
while (true)
|
||||||
|
{
|
||||||
|
var cRead = await stream2.ReadAsync(buffer, 0, buffer.Length);
|
||||||
|
if (cRead == 0)
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < cRead; i++)
|
||||||
|
{
|
||||||
|
receivedBuffer2.Add(buffer[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
await Task.WhenAll(t1, t2, t3);
|
||||||
|
|
||||||
|
Assert.True(receivedBuffer1.Count == expectedCount1, $"Demuxer error: expect to receive {expectedCount1} bytes, but actually got {receivedBuffer1.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer2.Count == expectedCount2, $"Demuxer error: expect to receive {expectedCount2} bytes, but actually got {receivedBuffer2.Count} bytes.");
|
||||||
|
Assert.True(receivedBuffer1[0] == 0xE1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[96] == 0xE1, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[97] == 0xE3, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer1[395] == 0xE3, "The second payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer2[0] == 0xE2, "The first payload incorrect!");
|
||||||
|
Assert.True(receivedBuffer2[196] == 0xE2, "The first payload incorrect!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
private static byte[] GenerateRandomBuffer(int length, byte channelIndex, byte content, bool portForward)
|
||||||
|
{
|
||||||
|
var buffer = GenerateRandomBuffer(length, content);
|
||||||
|
buffer[0] = channelIndex;
|
||||||
|
if (portForward)
|
||||||
|
{
|
||||||
|
if (length > 1)
|
||||||
|
{
|
||||||
|
buffer[1] = 0xFF; // the first port bytes
|
||||||
|
}
|
||||||
|
if (length > 2)
|
||||||
|
{
|
||||||
|
buffer[2] = 0xFF; // the 2nd port bytes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte[] GenerateRandomBuffer(int length, byte content)
|
||||||
|
{
|
||||||
|
var buffer = new byte[length];
|
||||||
|
for (int i = 0; i < length; i++)
|
||||||
|
{
|
||||||
|
buffer[i] = content;
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task<bool> WaitForAsync(Func<bool> handler, float waitForSeconds = 1)
|
||||||
|
{
|
||||||
|
Stopwatch w = Stopwatch.StartNew();
|
||||||
|
try
|
||||||
|
{
|
||||||
|
do
|
||||||
|
{
|
||||||
|
if (handler())
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
await Task.Delay(10);
|
||||||
|
} while (w.Elapsed.Duration().TotalSeconds < waitForSeconds);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
w.Stop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user