接上次Netty实现Websocket协议通信的例子
https://blog.csdn.net/xxkalychen/article/details/115903261?spm=1001.2014.3001.5501
整个示例对于客户端的标识都是channel的id或者远程地址,并不直观,我们希望有更加清楚的标识,比如用户名。我们可以制定协议或者上行报文格式,来区分哪些消息是身份标识、哪些是聊天内容。不过这样一来,用户身份就要等到连接成功之后第一次发送消息时才能确定。
其实我们也可以在连接的时候就带上我们的用户名,就像这样:
ws://localhost:8001/chat/chris
我们把用户名加在请求路径中,在连接的时候就进行注册,把用户名和channel绑定在一起。我们发言的时候就可以带着自己的用户名标识了。

思路是我们在Websocket协议处理之前拿到请求数据,把用户名截取出来,跟channel绑定。因为Websocket第一次请求就是http请求,我们是可以拿到请求URI信息的。
我们需要增加三个文件:一个是放在Websocket协议处理器之前、用来绑定用户名的自定义连接处理器;一个是自己封装的channel,在原有channel的基础上扩展了用户名字段(其实后来发现,几经修改,不该这么做,直接用map建立对应关系就好了);还有一个是全局的客户端管理类。
依赖
compile 'io.netty:netty-all:4.1.63.Final'
1. 自定义channel
package com.chris.ws.server.handler;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.*;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;
import java.net.SocketAddress;
/**
 * A {@link Channel} decorator that carries the username a client connected with.
 *
 * <p>Every {@link Channel} method is forwarded unchanged to the wrapped channel. A wrapper can
 * therefore be put into a {@code ChannelGroup} and used exactly like the channel it wraps
 * (written to, closed, looked up by id), while {@link #getUsername()} exposes the extra identity.
 *
 * @author Chris Chan
 * Create on 2021/4/20 14:21
 */
public class WebsocketChannel implements Channel {
    /** Username bound to this connection; not validated here (may be null). */
    private final String username;
    /** The real channel that every {@link Channel} call is delegated to. */
    private final Channel channel;

    private WebsocketChannel(String username, Channel channel) {
        this.username = username;
        // Fail at construction time instead of with an NPE on the first delegated call.
        this.channel = java.util.Objects.requireNonNull(channel, "channel");
    }

    /**
     * Wraps {@code channel} together with {@code username}.
     *
     * @param username the user's name; not validated here
     * @param channel  the channel to delegate to; must not be null
     * @return a new wrapper
     * @throws NullPointerException if {@code channel} is null
     */
    public static WebsocketChannel create(String username, Channel channel) {
        return new WebsocketChannel(username, channel);
    }

    /** @return the username this channel was registered with */
    public String getUsername() {
        return username;
    }

    /** @return the wrapped (real) channel */
    public Channel getChannel() {
        return channel;
    }

    // ---------------------------------------------------------------------------------------
    // Channel: every method below forwards to the wrapped channel, passing all arguments through.
    // ---------------------------------------------------------------------------------------

    @Override
    public ChannelId id() {
        return channel.id();
    }

    @Override
    public EventLoop eventLoop() {
        return channel.eventLoop();
    }

    @Override
    public Channel parent() {
        return channel.parent();
    }

    @Override
    public ChannelConfig config() {
        return channel.config();
    }

    @Override
    public boolean isOpen() {
        return channel.isOpen();
    }

    @Override
    public boolean isRegistered() {
        return channel.isRegistered();
    }

    @Override
    public boolean isActive() {
        return channel.isActive();
    }

    @Override
    public ChannelMetadata metadata() {
        // Forward like every other method; a null here would violate the Channel contract.
        return channel.metadata();
    }

    @Override
    public SocketAddress localAddress() {
        return channel.localAddress();
    }

    @Override
    public SocketAddress remoteAddress() {
        return channel.remoteAddress();
    }

    @Override
    public ChannelFuture closeFuture() {
        return channel.closeFuture();
    }

    @Override
    public boolean isWritable() {
        return channel.isWritable();
    }

    @Override
    public long bytesBeforeUnwritable() {
        return channel.bytesBeforeUnwritable();
    }

    @Override
    public long bytesBeforeWritable() {
        return channel.bytesBeforeWritable();
    }

    @Override
    public Unsafe unsafe() {
        return channel.unsafe();
    }

    @Override
    public ChannelPipeline pipeline() {
        return channel.pipeline();
    }

    @Override
    public ByteBufAllocator alloc() {
        return channel.alloc();
    }

    @Override
    public ChannelFuture bind(SocketAddress localAddress) {
        return channel.bind(localAddress);
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress) {
        return channel.connect(remoteAddress);
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) {
        return channel.connect(remoteAddress, localAddress);
    }

    @Override
    public ChannelFuture disconnect() {
        return channel.disconnect();
    }

    @Override
    public ChannelFuture close() {
        return channel.close();
    }

    @Override
    public ChannelFuture deregister() {
        return channel.deregister();
    }

    @Override
    public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) {
        return channel.bind(localAddress, promise);
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) {
        return channel.connect(remoteAddress, promise);
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) {
        return channel.connect(remoteAddress, localAddress, promise);
    }

    @Override
    public ChannelFuture disconnect(ChannelPromise promise) {
        return channel.disconnect(promise);
    }

    @Override
    public ChannelFuture close(ChannelPromise promise) {
        // Pass the caller's promise through, otherwise it would never be completed.
        return channel.close(promise);
    }

    @Override
    public ChannelFuture deregister(ChannelPromise promise) {
        return channel.deregister(promise);
    }

    @Override
    public Channel read() {
        return channel.read();
    }

    @Override
    public ChannelFuture write(Object msg) {
        return channel.write(msg);
    }

    @Override
    public ChannelFuture write(Object msg, ChannelPromise promise) {
        return channel.write(msg, promise);
    }

    @Override
    public Channel flush() {
        return channel.flush();
    }

    @Override
    public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
        // Must flush as well: a plain write() would leave the message in the outbound buffer.
        return channel.writeAndFlush(msg, promise);
    }

    @Override
    public ChannelFuture writeAndFlush(Object msg) {
        return channel.writeAndFlush(msg);
    }

    @Override
    public ChannelPromise newPromise() {
        return channel.newPromise();
    }

    @Override
    public ChannelProgressivePromise newProgressivePromise() {
        return channel.newProgressivePromise();
    }

    @Override
    public ChannelFuture newSucceededFuture() {
        return channel.newSucceededFuture();
    }

    @Override
    public ChannelFuture newFailedFuture(Throwable cause) {
        return channel.newFailedFuture(cause);
    }

    @Override
    public ChannelPromise voidPromise() {
        return channel.voidPromise();
    }

    @Override
    public <T> Attribute<T> attr(AttributeKey<T> key) {
        return channel.attr(key);
    }

    @Override
    public <T> boolean hasAttr(AttributeKey<T> key) {
        return channel.hasAttr(key);
    }

    @Override
    public int compareTo(Channel o) {
        return channel.compareTo(o);
    }
}
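这里用代理(装饰)模式来增强channel:WebsocketChannel实现了Channel接口,把所有方法都原样转发给被包装的原始channel,同时多带一个用户名。Channel接口的方法很多,Java要求全部实现;用不到的方法写成最简单的转发即可。但ChannelGroup自身的方法(例如群发writeAndFlush、关闭close)会直接调用这些方法,所以凡是会被用到的方法,都必须把参数完整地转发给原始channel。代码较长,我这里贴出来,可以直接复制。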
2. 全局客户端管理类。其主要目的就是处理用户名与channel的绑定等业务。
package com.chris.ws.server.handler;
import io.netty.channel.Channel;
import io.netty.channel.ChannelId;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.concurrent.DefaultEventExecutor;
import java.util.HashMap;
import java.util.Map;
/**
 * Global registry of connected chat clients: binds usernames to channels and broadcasts messages.
 *
 * <p>All state is static and is accessed from Netty handlers running on event-loop threads, so it
 * must be thread-safe. The {@link ChannelGroup} is thread-safe by contract, and the username map is
 * a {@code ConcurrentHashMap}. The map is kept in addition to the group because a
 * {@link ChannelGroup} removes a channel automatically once it is closed. Callers (for example a
 * {@code channelInactive} handler) may still need the departing user's name after that.
 *
 * @author Chris Chan
 * Create on 2021/4/20 14:44
 */
public class ClientManager {
    // GlobalEventExecutor.INSTANCE is the usual alternative executor for the group's futures.
    private static final ChannelGroup channelGroup = new DefaultChannelGroup(new DefaultEventExecutor());
    /** Channel id -> trimmed username, for every channel registered through this class. */
    private static final Map<ChannelId, String> userMap = new java.util.concurrent.ConcurrentHashMap<>(16);

    /**
     * Returns the group holding every registered channel. It can be used for broadcasts or lookups.
     *
     * @return the shared channel group
     */
    public static ChannelGroup getChannelGroup() {
        return channelGroup;
    }

    /**
     * Wraps a plain channel with its username and registers it.
     *
     * @param username the user's name; must not be null or blank, and must not already be in use
     * @param channel  the client's channel
     * @throws IllegalArgumentException if {@code username} is null or blank
     * @throws IllegalStateException    if {@code username} (after trimming) is already registered
     */
    public static void addChannel(String username, Channel channel) {
        addWebsocketChannel(WebsocketChannel.create(username, channel));
    }

    /**
     * Registers an already built {@link WebsocketChannel}.
     *
     * <p>Usernames are compared and stored after trimming. The method is synchronized so that the
     * uniqueness check and the insertion happen atomically: otherwise two connections could claim
     * the same name concurrently.
     *
     * @param websocketChannel the wrapper to register
     * @throws IllegalArgumentException if its username is null or blank
     * @throws IllegalStateException    if its (trimmed) username is already registered
     */
    public static synchronized void addWebsocketChannel(WebsocketChannel websocketChannel) {
        String username = websocketChannel.getUsername();
        if (null == username || username.trim().isEmpty()) {
            throw new IllegalArgumentException("username can not be empty.");
        }
        String name = username.trim();
        if (userMap.containsValue(name)) {
            throw new IllegalStateException("the username is exists.");
        }
        channelGroup.add(websocketChannel);
        userMap.put(websocketChannel.id(), name);
    }

    /**
     * Finds the registered wrapper for {@code channel}.
     *
     * @param channel the channel to look up
     * @return the wrapper, or null if the channel is not in the group (or is not a wrapper)
     */
    public static WebsocketChannel find(Channel channel) {
        return find(channel.id());
    }

    /**
     * Finds the registered wrapper with the given id.
     *
     * @param channelId the channel id to look up
     * @return the wrapper, or null if no wrapper with that id is currently in the group
     */
    public static WebsocketChannel find(ChannelId channelId) {
        Channel channel = channelGroup.find(channelId);
        if (channel instanceof WebsocketChannel) {
            return (WebsocketChannel) channel;
        }
        return null;
    }

    /**
     * Returns the username bound to {@code channel}.
     *
     * @param channel the channel to look up
     * @return the username, or {@code "unknown"} if none is registered
     */
    public static String getUsername(Channel channel) {
        return getUsername(channel.id());
    }

    /**
     * Returns the username bound to {@code channelId}.
     *
     * <p>The (trimmed) name from the map is preferred. It stays available even after the group has
     * auto-removed a closed channel. Only for a wrapper that was added to the group directly is the
     * wrapper's own name used.
     *
     * @param channelId the channel id to look up
     * @return the username, or {@code "unknown"} if none is registered
     */
    public static String getUsername(ChannelId channelId) {
        String registered = userMap.get(channelId);
        if (null != registered && !registered.isEmpty()) {
            return registered;
        }
        WebsocketChannel websocketChannel = find(channelId);
        if (null != websocketChannel) {
            return websocketChannel.getUsername();
        }
        return "unknown";
    }

    /**
     * Unregisters {@code channel} and frees its username.
     *
     * @param channel the channel to remove
     */
    public static void removeChannel(Channel channel) {
        removeChannel(channel.id());
    }

    /**
     * Unregisters the channel with the given id and frees its username.
     *
     * <p>Works whether or not the group has already auto-removed the (closed) channel.
     *
     * @param channelId the id of the channel to remove
     */
    public static void removeChannel(ChannelId channelId) {
        Channel channel = channelGroup.find(channelId);
        if (channel instanceof WebsocketChannel) {
            removeWebsocketChannel((WebsocketChannel) channel);
        } else if (null != channel) {
            channelGroup.remove(channel);
        }
        // Always free the name: the group may already have dropped the channel on close.
        userMap.remove(channelId);
    }

    /**
     * Unregisters {@code websocketChannel} and frees its username.
     *
     * <p>The username is released even if the group has already removed the wrapper because the
     * channel closed. Otherwise the name would stay reserved forever.
     *
     * @param websocketChannel the wrapper to remove; null is ignored
     */
    public static void removeWebsocketChannel(WebsocketChannel websocketChannel) {
        if (null == websocketChannel) {
            return;
        }
        ChannelId channelId = websocketChannel.id();
        channelGroup.remove(websocketChannel);
        userMap.remove(channelId);
    }

    /**
     * @return the number of channels currently in the group
     */
    public static int size() {
        return channelGroup.size();
    }

    /**
     * Broadcasts {@code msg} as a text WebSocket frame to every channel in the group.
     *
     * @param msg the text to send
     */
    public static void send(String msg) {
        channelGroup.writeAndFlush(new TextWebSocketFrame(msg));
    }
}
其中封装了一些方法,方便使用。也可以直接复制。
3. 自定义连接处理器。用来从请求路径中提取用户名,并把用户名和channel绑定。
package com.chris.ws.server.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
/**
 * Inbound handler that binds a username to a connection during the WebSocket opening handshake.
 *
 * <p>Clients connect to {@code <path>/<username>}, e.g. {@code ws://localhost:8001/chat/chris}.
 * The handler expects aggregated {@link FullHttpRequest}s and must sit before the WebSocket
 * protocol handler. It does the following:
 * <ul>
 *   <li>reads the (percent-decoded) username from the last path segment;</li>
 *   <li>registers the username with {@link ClientManager} and announces the new user;</li>
 *   <li>rewrites the request URI to the bare {@code path}, so that a protocol handler configured
 *       with that path matches the request.</li>
 * </ul>
 * Requests without an {@code Origin} header, or with a username {@link ClientManager} refuses,
 * are dropped and the connection is closed.
 *
 * @author Chris Chan
 * Create on 2021/4/20 13:58
 */
public class WebsocketConnectHandler extends ChannelInboundHandlerAdapter {
    /** Endpoint path the username segment is appended to, e.g. {@code "/chat"}. */
    private String path = "/";

    /**
     * @param path endpoint path, e.g. {@code "/chat"}; usernames are read from
     *             {@code path + "/<username>"}
     */
    public WebsocketConnectHandler(String path) {
        this.path = path;
    }

    /**
     * Extracts and registers the username from the handshake request, then forwards the request.
     *
     * <p>Non-HTTP messages (e.g. WebSocket frames after the handshake) are forwarded untouched. A
     * rejected request is released here, not forwarded, and the connection is closed.
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpRequest)) {
            super.channelRead(ctx, msg);
            return;
        }
        FullHttpRequest httpRequest = (FullHttpRequest) msg;

        if (null == httpRequest.headers().get("Origin")) {
            // Not forwarded, so this handler owns the request and must release it.
            httpRequest.release();
            ctx.close();
            return;
        }

        String username = null;
        String uri = httpRequest.uri();
        if (null != uri) {
            String rawPath = stripQuery(uri);
            String prefix = path.endsWith("/") ? path : path + "/";
            // Exact prefix match, so that e.g. "/chatroom/x" is not mistaken for "/chat".
            if (rawPath.equals(path) || rawPath.startsWith(prefix)) {
                String segment = rawPath.startsWith(prefix) ? rawPath.substring(prefix.length()) : "";
                username = extractUsername(segment);
                httpRequest.setUri(path);
            }
        }

        if (null != username) {
            try {
                ClientManager.addChannel(username, ctx.channel());
            } catch (RuntimeException rejected) {
                // e.g. the name is already taken: refuse the connection instead of leaving it half-open.
                System.out.println("拒绝连接 " + username + " : " + rejected.getMessage());
                httpRequest.release();
                ctx.close();
                return;
            }
        }

        // Forward to the WebSocket protocol handler, which performs the handshake.
        super.channelRead(ctx, msg);

        // Announce only after a successful registration, so a refused name is never announced.
        if (null != username) {
            String info = username + " 上线 大家欢迎 目前在线 " + ClientManager.size() + "人";
            System.out.println(info);
            ClientManager.send(info);
        }
    }

    /**
     * Unregisters the channel and announces the departure to the remaining clients. It then
     * propagates the event to the rest of the pipeline.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        String username = ClientManager.getUsername(ctx.channel());
        ClientManager.removeChannel(ctx.channel().id());
        String info = username + " 下线 目前在线 " + ClientManager.size() + " 人";
        System.out.println(info);
        ClientManager.send(info);
        // Keep the event flowing so later handlers in the pipeline also see it.
        super.channelInactive(ctx);
    }

    /** Returns {@code uri} without its query string (everything from the first '?'). */
    private static String stripQuery(String uri) {
        int query = uri.indexOf('?');
        return query < 0 ? uri : uri.substring(0, query);
    }

    /**
     * Turns the path segment after the endpoint into a username.
     *
     * @param segment raw (still percent-encoded) text after {@code path + "/"}
     * @return the decoded username, or null if the segment is empty, blank, or spans several
     *         path segments
     */
    private static String extractUsername(String segment) {
        if (segment.isEmpty() || segment.indexOf('/') >= 0) {
            return null;
        }
        String decoded;
        try {
            // Browsers percent-encode non-ASCII names. '+' is a literal in a URI path, so protect it
            // from URLDecoder's form-style "+ means space" rule.
            decoded = java.net.URLDecoder.decode(segment.replace("+", "%2B"), "UTF-8");
        } catch (IllegalArgumentException malformedEscape) {
            // Keep the raw segment when the escapes are malformed.
            decoded = segment;
        } catch (java.io.UnsupportedEncodingException impossible) {
            // UTF-8 support is mandatory on every Java platform.
            throw new IllegalStateException(impossible);
        }
        return decoded.trim().isEmpty() ? null : decoded;
    }
}
4. 自定义消息处理器。主要工作就是转发消息。
package com.chris.ws.server.handler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
/**
 * Chat relay: every text frame received from a client is re-broadcast to all registered clients.
 * Each broadcast is prefixed with the sender's username as resolved by {@link ClientManager}.
 *
 * @author Chris Chan
 * Create on 2021/4/18 15:14
 */
public class WebSocketServerHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception {
        String sender = ClientManager.getUsername(ctx.channel());
        String text = msg.text();
        String line = sender + " : " + text;
        // Log locally, then fan out to everyone in the group.
        System.out.println(line);
        ClientManager.send(line);
    }
}
5. 初始化器
package com.chris.ws.server.handler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
/**
 * Builds each accepted connection's pipeline, in order:
 * <ol>
 *   <li>HTTP codec, chunked writer and aggregator for the opening handshake;</li>
 *   <li>the username-binding {@link WebsocketConnectHandler}, which must come before the protocol
 *       handler so it sees the original URI;</li>
 *   <li>the WebSocket protocol handler for the endpoint path;</li>
 *   <li>the chat relay.</li>
 * </ol>
 *
 * @author Chris Chan
 * Create on 2021/4/18 15:10
 */
public class WebSocketServerInitializer extends ChannelInitializer<SocketChannel> {
    /** Endpoint path shared by the connect handler and the protocol handler. */
    private static final String WEBSOCKET_PATH = "/chat";

    @Override
    protected void initChannel(SocketChannel ch) throws Exception {
        ch.pipeline().addLast("HttpServerCodec", new HttpServerCodec());
        ch.pipeline().addLast("ChunkedWriteHandler", new ChunkedWriteHandler());
        ch.pipeline().addLast("HttpObjectAggregator", new HttpObjectAggregator(8192));
        ch.pipeline().addLast("WebsocketConnectHandler", new WebsocketConnectHandler(WEBSOCKET_PATH));
        ch.pipeline().addLast("WebSocketServerProtocolHandler", new WebSocketServerProtocolHandler(WEBSOCKET_PATH));
        ch.pipeline().addLast("WebSocketServerHandler", new WebSocketServerHandler());
    }
}
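与上一篇相比,只增加了一个自定义的连接处理器(WebsocketConnectHandler)。注意顺序:它必须放在Websocket协议处理器(WebSocketServerProtocolHandler)之前。这样它才能在握手请求交给协议处理器之前拿到原始URI,提取用户名并改写路径。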
6. 服务器启动类(用于测试),与上一篇相比没有任何改动。
package com.chris.ws.server;
import com.chris.ws.server.handler.WebSocketServerInitializer;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
/**
 * Entry point: starts the WebSocket chat server on port 8001. It blocks until the server
 * channel is closed, then shuts both event-loop groups down.
 *
 * @author Chris Chan
 * Create on 2021/4/18 15:09
 */
public class WebSocketServer {
    public static void main(String[] args) {
        // The parent group accepts connections; the child group serves the accepted channels.
        EventLoopGroup parentGroup = new NioEventLoopGroup();
        EventLoopGroup childGroup = new NioEventLoopGroup();
        try {
            new ServerBootstrap()
                    .group(parentGroup, childGroup)
                    .channel(NioServerSocketChannel.class)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new WebSocketServerInitializer())
                    .bind(8001)
                    .sync()
                    .channel()
                    .closeFuture()
                    .sync();
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            parentGroup.shutdownGracefully();
            childGroup.shutdownGracefully();
        }
    }
}
测试:启动服务器后,找几个在线Websocket测试网站,分别用 ws://localhost:8001/chat/用户名 这样的地址连接,然后互相发送消息进行测试。