跳到主要内容

一文入门 Seata 网络通信源码

· 阅读需 55 分钟

在前几篇文章中,我们详细聊了聊 Seata 的 XA、AT 以及 TCC 模式,它们都是在 Seata 定义的全局框架下的不同的事务模式。

我们知道,在 Seata 中,有三类角色,TC、RM、TM,Seata Server 作为 TC 协调分支事务的提交和回滚,各个资源作为 RM 和 TM,那么这三者之间是如何通信的?

所以,这篇文章就来看看 Seata 底层是如何进行网络通信的。

整体类层次结构

我们先着眼大局,看一看 Seata 整个 RPC 的类层次结构。

image-20241217222005964

从类结构层次可以看出来,AbstractNettyRemoting 是整个 Seata 网络通信的一个顶层抽象类。

在这个类中主要实现了一些 RPC 的基础通用方法,比如同步调用 sendSync、异步调用 sendAsync 等。

事实上,就网络调用来说,无非就是同步调用和异步调用,像其他的什么请求和响应都只是报文内容的区分。

所以,在 Seata 中,我个人认为还差一个顶层的接口 Remoting,类似于下面这样的:

import io.netty.channel.Channel;
import java.util.concurrent.TimeoutException;

public interface Remoting<Req, Resp> {

/**
* 同步调用
*/
Resp sendSync(Channel channel, Req request, long timeout) throws TimeoutException;

/**
* 异步调用
*/
void sendAsync(Channel channel, Req request);
}

在 AbstractNettyRemoting 实现了通用的网络调用方法,但是不同角色在这方面还是有一些区分的,比如对于 Server 来说,它的请求调用需要知道向哪个客户端发送,而对于 TM、RM 来说,它们发送请求直接发就行,不需要指定某个特定的 TC 服务,只需要在实现类通过负载均衡算法找到合适的 Server 节点就行。

所以就区分出了 RemotingServer 和 RemotingClient,但是底层还是要依赖 AbstractNettyRemoting 进行网络调用的,所以它们各自有子类实现了 AbstractNettyRemoting。

可以说 Seata 的这种设计在我看来是非常不错的,对于这种 CS 架构的远程通信,可以算一种通用的设计方案。

如何启动 Server 和 Client

聊完了 Seata 底层的类层次,我们再分别以 Server 和 Client 的视角来看它们是如何启动的,以及在启动的时候需要做些什么事情。

Server 是怎么启动的

Seata Server 作为一个独立的 SpringBoot 项目,要怎么样才能在 SpringBoot 启动的时候自动做点事呢?

Seata 的做法是实现了 CommandLineRunner 接口,至于这里面的原理就不是本篇文章讨论的内容了。

我们主要关注它的 run 方法:

// org.apache.seata.server.ServerRunner#run
public void run(String... args) {
try {
long start = System.currentTimeMillis();
seataServer.start(args);
started = true;
long cost = System.currentTimeMillis() - start;
LOGGER.info("\r\n you can visit seata console UI on http://127.0.0.1:{}. \r\n log path: {}.", this.port, this.logPath);
LOGGER.info("seata server started in {} millSeconds", cost);
} catch (Throwable e) {
started = Boolean.FALSE;
LOGGER.error("seata server start error: {} ", e.getMessage(), e);
System.exit(-1);
}
}

这其中核心的逻辑就在 seataServer.start() 方法中:

// org.apache.seata.server.Server#start
public void start(String[] args) {
// 参数解析器,用于解析 sh 的启动参数
ParameterParser parameterParser = new ParameterParser(args);
// initialize the metrics
MetricsManager.get().init();
ThreadPoolExecutor workingThreads = new ThreadPoolExecutor(NettyServerConfig.getMinServerPoolSize(),
NettyServerConfig.getMaxServerPoolSize(), NettyServerConfig.getKeepAliveTime(), TimeUnit.SECONDS,
new LinkedBlockingQueue<>(NettyServerConfig.getMaxTaskQueueSize()),
new NamedThreadFactory("ServerHandlerThread", NettyServerConfig.getMaxServerPoolSize()), new ThreadPoolExecutor.CallerRunsPolicy());
// 127.0.0.1 and 0.0.0.0 are not valid here.
if (NetUtil.isValidIp(parameterParser.getHost(), false)) {
XID.setIpAddress(parameterParser.getHost());
} else {
String preferredNetworks = ConfigurationFactory.getInstance().getConfig(REGISTRY_PREFERED_NETWORKS);
if (StringUtils.isNotBlank(preferredNetworks)) {
XID.setIpAddress(NetUtil.getLocalIp(preferredNetworks.split(REGEX_SPLIT_CHAR)));
} else {
XID.setIpAddress(NetUtil.getLocalIp());
}
}
/**
* 主要做这么几件事:
* 1. 设置 workingThreads 为 AbstractNettyRemoting 的 messageExecutor 处理器
* 2. 创建 ServerBootstrap,配置 Boss 和 Worker,并且设置 Seata Server 需要监听的端口
* 3. 设置出栈、入栈处理器 ServerHandler,它是一个 ChannelDuplexHandler 复合的处理器
*/
NettyRemotingServer nettyRemotingServer = new NettyRemotingServer(workingThreads);
XID.setPort(nettyRemotingServer.getListenPort());
UUIDGenerator.init(parameterParser.getServerNode());
ConfigurableListableBeanFactory beanFactory = ((GenericWebApplicationContext) ObjectHolder.INSTANCE.getObject(OBJECT_KEY_SPRING_APPLICATION_CONTEXT)).getBeanFactory();
DefaultCoordinator coordinator = DefaultCoordinator.getInstance(nettyRemotingServer);
if (coordinator instanceof ApplicationListener) {
beanFactory.registerSingleton(NettyRemotingServer.class.getName(), nettyRemotingServer);
beanFactory.registerSingleton(DefaultCoordinator.class.getName(), coordinator);
((GenericWebApplicationContext) ObjectHolder.INSTANCE.getObject(OBJECT_KEY_SPRING_APPLICATION_CONTEXT)).addApplicationListener((ApplicationListener<?>) coordinator);
}
// log store mode: file, db, redis
SessionHolder.init();
LockerManagerFactory.init();
// 初始化一系列定时线程池,用于重试事务提交/回滚等
coordinator.init();
// 设置事务处理 Handler 为 DefaultCoordinator
nettyRemotingServer.setHandler(coordinator);
serverInstance.serverInstanceInit();
// let ServerRunner do destroy instead ShutdownHook, see https://github.com/seata/seata/issues/4028
ServerRunner.addDisposable(coordinator);
// Server 初始化
nettyRemotingServer.init();
}

最后的 nettyRemotingServer.init() 是整个 Seata Server 启动的重要逻辑,主要做了这么几件事:

  1. 注册一系列处理器
  2. 初始化一个定时线程池,用于清理过期的 MessageFuture
  3. 启动 ServerBootStrap 并将 TC 服务注册到注册中心,比如 Nacos

注册处理器

在 Seata 内部,用一个 Pair 对象关联了处理器和线程池,如下:

package org.apache.seata.core.rpc.processor;

public final class Pair<T1, T2> {

private final T1 first;
private final T2 second;

public Pair(T1 first, T2 second) {
this.first = first;
this.second = second;
}

public T1 getFirst() {
return first;
}

public T2 getSecond() {
return second;
}
}

而注册处理器本质就是将报文类型、处理该报文的处理器以及具体执行的线程池关联起来,存到一张哈希表中。

// AbstractNettyRemotingServer
protected final Map<Integer/*MessageType*/, Pair<RemotingProcessor, ExecutorService>> processorTable = new HashMap<>(32);
// org.apache.seata.core.rpc.netty.NettyRemotingServer#registerProcessor
private void registerProcessor() {
// 1. registry on request message processor
ServerOnRequestProcessor onRequestProcessor = new ServerOnRequestProcessor(this, getHandler());
ShutdownHook.getInstance().addDisposable(onRequestProcessor);
super.registerProcessor(MessageType.TYPE_BRANCH_REGISTER, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_BRANCH_STATUS_REPORT, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_BEGIN, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_COMMIT, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_LOCK_QUERY, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_REPORT, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_ROLLBACK, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_GLOBAL_STATUS, onRequestProcessor, messageExecutor);
super.registerProcessor(MessageType.TYPE_SEATA_MERGE, onRequestProcessor, messageExecutor);
// 2. registry on response message processor
ServerOnResponseProcessor onResponseProcessor = new ServerOnResponseProcessor(getHandler(), getFutures());
super.registerProcessor(MessageType.TYPE_BRANCH_COMMIT_RESULT, onResponseProcessor, branchResultMessageExecutor);
super.registerProcessor(MessageType.TYPE_BRANCH_ROLLBACK_RESULT, onResponseProcessor, branchResultMessageExecutor);
// 3. registry rm message processor
RegRmProcessor regRmProcessor = new RegRmProcessor(this);
super.registerProcessor(MessageType.TYPE_REG_RM, regRmProcessor, messageExecutor);
// 4. registry tm message processor
RegTmProcessor regTmProcessor = new RegTmProcessor(this);
super.registerProcessor(MessageType.TYPE_REG_CLT, regTmProcessor, null);
// 5. registry heartbeat message processor
ServerHeartbeatProcessor heartbeatMessageProcessor = new ServerHeartbeatProcessor(this);
super.registerProcessor(MessageType.TYPE_HEARTBEAT_MSG, heartbeatMessageProcessor, null);
}


// org.apache.seata.core.rpc.netty.AbstractNettyRemotingServer#registerProcessor
public void registerProcessor(int messageType, RemotingProcessor processor, ExecutorService executor) {
Pair<RemotingProcessor, ExecutorService> pair = new Pair<>(processor, executor);
this.processorTable.put(messageType, pair);
}

你可能会注意到,在注册处理器时,有一些注册时传入的线程池是 null,那么对应的报文会由哪个线程执行呢?

后面我们会提到。

初始化定时线程池

// org.apache.seata.core.rpc.netty.AbstractNettyRemoting#init
public void init() {
timerExecutor.scheduleAtFixedRate(() -> {
for (Map.Entry<Integer, MessageFuture> entry : futures.entrySet()) {
MessageFuture future = entry.getValue();
if (future.isTimeout()) {
futures.remove(entry.getKey());
RpcMessage rpcMessage = future.getRequestMessage();
future.setResultMessage(new TimeoutException(String.format("msgId: %s, msgType: %s, msg: %s, request timeout",
rpcMessage.getId(), String.valueOf(rpcMessage.getMessageType()), rpcMessage.getBody().toString())));
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("timeout clear future: {}", entry.getValue().getRequestMessage().getBody());
}
}
}
nowMills = System.currentTimeMillis();
}, TIMEOUT_CHECK_INTERVAL, TIMEOUT_CHECK_INTERVAL, TimeUnit.MILLISECONDS);
}

这个没啥好说的,就是初始化了一个定时线程池定时清理那些超时的 MessageFuture,这里 MessageFuture 是 Seata 将异步调用转为同步调用的关键,我们后面也会详细说到。

启动 ServerBootStrap

最后启动 ServerBootStrap,这差不多就是 Netty 的内容了。

// org.apache.seata.core.rpc.netty.NettyServerBootstrap#start
public void start() {
int port = getListenPort();
this.serverBootstrap.group(this.eventLoopGroupBoss, this.eventLoopGroupWorker)
.channel(NettyServerConfig.SERVER_CHANNEL_CLAZZ)
.option(ChannelOption.SO_BACKLOG, nettyServerConfig.getSoBackLogSize())
.option(ChannelOption.SO_REUSEADDR, true)
.childOption(ChannelOption.SO_KEEPALIVE, true)
.childOption(ChannelOption.TCP_NODELAY, true)
.childOption(ChannelOption.SO_SNDBUF, nettyServerConfig.getServerSocketSendBufSize())
.childOption(ChannelOption.SO_RCVBUF, nettyServerConfig.getServerSocketResvBufSize())
.childOption(ChannelOption.WRITE_BUFFER_WATER_MARK, new WriteBufferWaterMark(nettyServerConfig.getWriteBufferLowWaterMark(), nettyServerConfig.getWriteBufferHighWaterMark()))
.localAddress(new InetSocketAddress(port))
.childHandler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
// 多版本协议解码器
MultiProtocolDecoder multiProtocolDecoder = new MultiProtocolDecoder(channelHandlers);
ch.pipeline()
.addLast(new IdleStateHandler(nettyServerConfig.getChannelMaxReadIdleSeconds(), 0, 0))
.addLast(multiProtocolDecoder);
}
});
try {
this.serverBootstrap.bind(port).sync();
LOGGER.info("Server started, service listen port: {}", getListenPort());
InetSocketAddress address = new InetSocketAddress(XID.getIpAddress(), XID.getPort());
for (RegistryService<?> registryService : MultiRegistryFactory.getInstances()) {
// 注册服务
registryService.register(address);
}
initialized.set(true);
} catch (SocketException se) {
throw new RuntimeException("Server start failed, the listen port: " + getListenPort(), se);
} catch (Exception exx) {
throw new RuntimeException("Server start failed", exx);
}
}

ServerBootstrap 启动时的 childOption 属于网络部分的内容,我们不过多解释。

这里你可能有一点疑问,在 pipeline 中仅仅只是添加了一个 MultiProtocolDecoder 解码器,那业务处理器呢?

事实上,MultiProtocolDecoder 的构造参数中的 channelHandlers 就是 ServerHandler,它是在创建 NettyRemotingServer 时就被设置的。

至于为什么要这样做,其实是和 Seata 的多版本协议相关。

当 Seata Server 启动后第一次进行解码时,会将 MultiProtocolDecoder 从 pipeline 中移除,根据版本选择具体的 Encoder 和 Decoder 并添加到 pipeline 中,此时,也会将 ServerHandler 添加到 pipeline。

Client 是怎么启动的

对于 Client 来说,由于我们一般是在 SpringBoot 中使用 Seata,所以我们需要关注的点在 SeataAutoConfiguration 类中。

在这个类里面创建了一个 GlobalTransactionScanner 对象,我们注意到它实现了 InitializingBean,所以将目光转移到 afterPropertiesSet 方法上。

果然在这个方法里面进行了 TM 和 RM 的初始化。

TM 的初始化

对于 TM 来说,初始化的逻辑如下:

public static void init(String applicationId, String transactionServiceGroup, String accessKey, String secretKey) {
/**
* 主要做这么几件事
* 1. 创建线程池作为 AbstractNettyRemotingClient 的 messageExecutor
* 2. 设置事务角色 transactionRole 为 TM_ROLE
* 3. 创建 Bootstrap 并设置出栈、入栈处理器 ClientHandler
* 4. 创建客户端 Channel 管理器 NettyClientChannelManager
*/
TmNettyRemotingClient tmNettyRemotingClient = TmNettyRemotingClient.getInstance(applicationId, transactionServiceGroup, accessKey, secretKey);

/**
* 主要做这么几件事:
* 1. 注册一系列处理器
* 2. 创建定时线程池定时对事务组内的 Server 发起连接,如果连接断开,则尝试重新建立连接
* 3. 如果客户端允许报文批量发送,则创建 mergeSendExecutorService 线程池,并提交 MergedSendRunnable 任务
* 4. 初始化一个定时线程池清理过期的 MessageFuture
* 5. 启动客户端 Bootstrap
* 6. 初始化连接 initConnection
*/
tmNettyRemotingClient.init();
}

启动客户端 Bootstrap 的逻辑如下:

@Override
public void start() {
if (this.defaultEventExecutorGroup == null) {
this.defaultEventExecutorGroup = new DefaultEventExecutorGroup(nettyClientConfig.getClientWorkerThreads(),
new NamedThreadFactory(getThreadPrefix(nettyClientConfig.getClientWorkerThreadPrefix()), nettyClientConfig.getClientWorkerThreads()));
}
this.bootstrap.group(this.eventLoopGroupWorker)
.channel(nettyClientConfig.getClientChannelClazz())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, nettyClientConfig.getConnectTimeoutMillis())
.option(ChannelOption.SO_SNDBUF, nettyClientConfig.getClientSocketSndBufSize())
.option(ChannelOption.SO_RCVBUF, nettyClientConfig.getClientSocketRcvBufSize());
if (nettyClientConfig.enableNative()) {
if (PlatformDependent.isOsx()) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("client run on macOS");
}
} else {
bootstrap.option(EpollChannelOption.EPOLL_MODE, EpollMode.EDGE_TRIGGERED)
.option(EpollChannelOption.TCP_QUICKACK, true);
}
}
bootstrap.handler(new ChannelInitializer<SocketChannel>() {
@Override
public void initChannel(SocketChannel ch) {
ch.pipeline().addLast(new IdleStateHandler(nettyClientConfig.getChannelMaxReadIdleSeconds(),
nettyClientConfig.getChannelMaxWriteIdleSeconds(),
nettyClientConfig.getChannelMaxAllIdleSeconds()))
.addLast(new ProtocolDecoderV1())
.addLast(new ProtocolEncoderV1());
if (channelHandlers != null) {
addChannelPipelineLast(ch, channelHandlers);
}
}
});
if (initialized.compareAndSet(false, true) && LOGGER.isInfoEnabled()) {
LOGGER.info("NettyClientBootstrap has started");
}
}

由于客户端的协议版本根据不同的 Seata 版本是可以确定的,所以这里直接添加了 V1 版本的编解码器,这里 channelHandlers 其实就是 ClientHandler,它也是 Netty 中的一个复合处理器。

RM 的初始化

RM 的初始化大致逻辑和 TM 是类似的,这里就不过多介绍了。

如何发送和处理报文

厘清了 Seata Server 和 Client 的大致启动流程之后,我们就可以深入的看一看 Seata 是如何进行报文发送和处理的。

前面我们也说过了,发送请求和处理报文的核心逻辑是在 AbstractNettyRemoting 中,接下来就看一看这个类。

同步和异步

先简单说一说什么是同步和异步。

同步 Synchronous 和异步 Asynchronous,本质上是描述了程序在处理多个事件或者任务时的不同行为模式。

同步是指一个过程必须等待另一个过程完成之后才能继续进行。换句话说,在同步操作中,调用方发出请求后会一直阻塞等待直到接收到响应结果、或者超时才会继续执行后续代码。

相比之下,异步则允许调用者在请求后不必等待响应就可以向下执行,但当请求完成时,会以某种方式将响应通知到调用者(如通过回调函数、Future),异步模型可以提高并发性和效率。

从另一个角度来说,同步调用需要发起调用的线程获取结果,而异步调用则是由异步线程将结果放到某个地方(Future)或者是异步线程去执行事先准备好的调用成功/失败的回调方法(回调函数)。

下面是一个简单的例子,展示了三种调用方式,同步、异步 Future、异步 Callback。

import lombok.Data;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

public class AsyncTest {

private static final Logger LOGGER = LoggerFactory.getLogger(AsyncTest.class);

public static void main(String[] args) throws InterruptedException, ExecutionException {
Result syncResponse = testSync();
LOGGER.info("同步响应结果: {}", syncResponse.getString());
CompletableFuture<Result> result = testAsyncFuture();
testAsyncCallback();
LOGGER.info("主线程继续向下执行~~");
TimeUnit.SECONDS.sleep(1); // 保证所有结果处理完毕
LOGGER.info("主线程从异步 Future 中获取结果: {}", result.get().getString());
}

public static void testAsyncCallback() {
new AsyncTask().execute(new AsyncCallback() {
@Override
public void onComplete(Result result) {
try {
TimeUnit.MILLISECONDS.sleep(50); // 模拟异步耗时
} catch (InterruptedException e) {
}
LOGGER.info("异步 Callback 获取结果: {}", result.getString());
}
});
}

public static CompletableFuture<Result> testAsyncFuture() {
return CompletableFuture.supplyAsync(() -> {
try {
TimeUnit.MILLISECONDS.sleep(50); // 模拟异步耗时
} catch (InterruptedException e) {
}
Result asyncResponse = getResult();
LOGGER.info("异步 Future 获取结果: {}", asyncResponse.getString());
return asyncResponse;
});
}

public static Result testSync() {
return getResult();
}

@Data
static class Result {
private String string;
}

interface AsyncCallback {
void onComplete(Result result);
}

static class AsyncTask {
void execute(AsyncCallback callback) {
new Thread(() -> {
Result asyncRes = getResult();
callback.onComplete(asyncRes);
}).start();
}
}

private static Result getResult() {
Result result = new Result();
result.setString("结果");
return result;
}
}

输出:

22:26:38.788 [main] INFO  org.hein.netty.AsyncTest - 同步响应结果: 结果
22:26:38.849 [main] INFO org.hein.netty.AsyncTest - 主线程继续向下执行~~
22:26:38.911 [Thread-0] INFO org.hein.netty.AsyncTest - 异步 Callback 获取结果: 结果
22:26:38.911 [ForkJoinPool.commonPool-worker-1] INFO org.hein.netty.AsyncTest - 异步 Future 获取结果: 结果
22:26:39.857 [main] INFO org.hein.netty.AsyncTest - 主线程从异步 Future 中获取结果: 结果

从结果中,至少可以看出三点,

  • 一是异步 Future 和异步 Callback 并不会阻塞主线程向下执行。
  • 二是异步调用时处理结果的不是主线程。
  • 最后,Future 和 Callback 的区别在于 Future 只是由异步线程将结果存储在了一个地方(CompletableFuture#result),但是后续获取结果还是需要主线程(或者其他线程)调用 get 方法,而 Callback 的话,其实就相当于预先设定了结果的处理方式,由异步线程去执行就好了。

当然,CompletableFuture 也是可以作回调的,比如调用 whenComplete 方法。

异步调用

Netty 作为一个高性能的异步 IO 框架,它的设计核心就是异步的,所以基于 Netty 进行异步调用是比较简单的。

protected void sendAsync(Channel channel, RpcMessage rpcMessage) {
channelWritableCheck(channel, rpcMessage.getBody());
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("write message: {}, channel: {}, active? {}, writable? {}, isopen? {}", rpcMessage.getBody(), channel, channel.isActive(), channel.isWritable(), channel.isOpen());
}
doBeforeRpcHooks(ChannelUtil.getAddressFromChannel(channel), rpcMessage);
channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
if (!future.isSuccess()) {
destroyChannel(future.channel());
}
});
}

只需要简单调用 channel 的 writeAndFlush 方法即可实现异步调用。

特别要注意的是,writeAndFlush 方法在调用线程是 EventLoop 线程的情况下会变成同步调用。

同步调用

在 Netty 中实现异步调用很简单,要实现同步调用就麻烦一点,需要将异步调用转换为同步调用。

从本质上来说,异步转同步就是让调用线程发起调用后,拿到响应前进入阻塞,拿到响应后再唤醒它,向下执行。

那么 Seata 的处理的核心就是 MessageFuture 类,如下:

package org.apache.seata.core.protocol;

import org.apache.seata.common.exception.ShouldNeverHappenException;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class MessageFuture {

private RpcMessage requestMessage;
private long timeout;
private final long start = System.currentTimeMillis();

private final transient CompletableFuture<Object> origin = new CompletableFuture<>();

public boolean isTimeout() {
return System.currentTimeMillis() - start > timeout;
}

public Object get(long timeout, TimeUnit unit) throws TimeoutException, InterruptedException {
Object result;
try {
result = origin.get(timeout, unit);
if (result instanceof TimeoutException) {
throw (TimeoutException) result;
}
} catch (ExecutionException e) {
throw new ShouldNeverHappenException("Should not get results in a multi-threaded environment", e);
} catch (TimeoutException e) {
throw new TimeoutException(String.format("%s, cost: %d ms", e.getMessage(), System.currentTimeMillis() - start));
}
if (result instanceof RuntimeException) {
throw (RuntimeException) result;
} else if (result instanceof Throwable) {
throw new RuntimeException((Throwable) result);
}
return result;
}

public void setResultMessage(Object obj) {
origin.complete(obj);
}

public RpcMessage getRequestMessage() { return requestMessage; }

public void setRequestMessage(RpcMessage requestMessage) { this.requestMessage = requestMessage;}

public long getTimeout() { return timeout; }

public void setTimeout(long timeout) { this.timeout = timeout;}
}

有了这个类之后,同步调用的过程如下,我们以客户端请求、服务端响应为例:

  • 首先客户端将请求构建为 MessageFuture,然后将请求 id 和这个 MessageFuture 对象存储到一个哈希表中。
  • 接着客户端调用 channel.writeAndFlush 发起异步调用,是的,这里还是异步。
  • 异步转同步的核心在于,此时线程需要调用 MessageFuture 对象的 get 方法进入阻塞,当然实际是调用了 CompletableFuture 的 get 方法进入同步阻塞。
  • 当服务端处理完毕,它又会发出请求(服务端视角),在客户端来看,这就是响应。
  • 当客户端收到响应之后,由 EventLoop 线程将响应结果设置到 MessageFuture 中,由于一次请求和响应的 id 是相同的,所以可以从上面的哈希表中拿到对应的 MessageFuture 对象。
  • 当响应结果被设置之后,上面阻塞的线程就可以恢复运行,这样就实现了同步的效果。

所以,Seata 的解决方案本质上来说就是利用了 CompletableFuture 对象,将它作为一个存储结果的容器。

protected Object sendSync(Channel channel, RpcMessage rpcMessage, long timeoutMillis) throws TimeoutException {
if (timeoutMillis <= 0) {
throw new FrameworkException("timeout should more than 0ms");
}
if (channel == null) {
LOGGER.warn("sendSync nothing, caused by null channel.");
return null;
}
MessageFuture messageFuture = new MessageFuture();
messageFuture.setRequestMessage(rpcMessage);
messageFuture.setTimeout(timeoutMillis);
futures.put(rpcMessage.getId(), messageFuture); // 请求和响应的 id 是一样的
// 检查该 Channel 是否可写(Channel 中有写缓冲区,如果缓冲区达到阈值水位,则不可写)
channelWritableCheck(channel, rpcMessage.getBody());
// 获取目的 ip 地址
String remoteAddr = ChannelUtil.getAddressFromChannel(channel);
// 执行发送前钩子方法
doBeforeRpcHooks(remoteAddr, rpcMessage);
// 发送结果,并设置回调,非阻塞
channel.writeAndFlush(rpcMessage).addListener((ChannelFutureListener) future -> {
// 发送失败,移除 future,关闭 Channel
if (!future.isSuccess()) {
MessageFuture mf = futures.remove(rpcMessage.getId());
if (mf != null) {
mf.setResultMessage(future.cause());
}
destroyChannel(future.channel());
}
});
try {
// Netty 是异步发送,所以这里需要等待结果,将异步转为同步
Object result = messageFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
// 执行发送后的钩子方法
doAfterRpcHooks(remoteAddr, rpcMessage, result);
return result;
} catch (Exception exx) {
LOGGER.error("wait response error:{},ip:{},request:{}", exx.getMessage(), channel.remoteAddress(), rpcMessage.getBody());
// 超时异常
if (exx instanceof TimeoutException) {
throw (TimeoutException) exx;
} else {
throw new RuntimeException(exx);
}
}
}

报文处理

在 Netty 中,提到报文处理,我们首先应该想到的就是入栈、出栈处理器。

在 Seata Server 端,除了常见的编解码处理器之外,就是 ServerHandler 处理器了,如下:

@ChannelHandler.Sharable
class ServerHandler extends ChannelDuplexHandler {

@Override
public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
// 前置了解码处理器,所以这里的消息是 RpcMessage
if (msg instanceof RpcMessage) {
processMessage(ctx, (RpcMessage) msg);
} else {
LOGGER.error("rpcMessage type error");
}
}

// ...
}

比较有业务含义的就是这个 channelRead 方法,所有发向 Server 的报文在经过解码之后都会来到这个方法。

这里的 processMessage 方法就是 AbstractNettyRemoting 中的业务处理方法,如下:

protected void processMessage(ChannelHandlerContext ctx, RpcMessage rpcMessage) throws Exception {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("{} msgId: {}, body: {}", this, rpcMessage.getId(), rpcMessage.getBody());
}
Object body = rpcMessage.getBody();
if (body instanceof MessageTypeAware) {
MessageTypeAware messageTypeAware = (MessageTypeAware) body;
// 在 Server 启动的时候,向 processorTable 注册了一大堆处理器
final Pair<RemotingProcessor, ExecutorService> pair = this.processorTable.get((int) messageTypeAware.getTypeCode());
if (pair != null) {
// 拿到对应的线程池执行
if (pair.getSecond() != null) {
try {
pair.getSecond().execute(() -> {
try {
// 找对应的处理器执行
pair.getFirst().process(ctx, rpcMessage);
} catch (Throwable th) {
LOGGER.error(FrameworkErrorCode.NetDispatch.getErrCode(), th.getMessage(), th);
} finally {
MDC.clear();
}
});
} catch (RejectedExecutionException e) {
// 线程池满了,执行拒绝策略
LOGGER.error(FrameworkErrorCode.ThreadPoolFull.getErrCode(), "thread pool is full, current max pool size is " + messageExecutor.getActiveCount());
if (allowDumpStack) {
// 导出线程栈信息
String name = ManagementFactory.getRuntimeMXBean().getName();
String pid = name.split("@")[0];
long idx = System.currentTimeMillis();
try {
String jstackFile = idx + ".log";
LOGGER.info("jstack command will dump to {}", jstackFile);
Runtime.getRuntime().exec(String.format("jstack %s > %s", pid, jstackFile));
} catch (IOException exx) {
LOGGER.error(exx.getMessage());
}
allowDumpStack = false;
}
}
} else {
try {
// 如果没有为处理器配置线程池,则由当前线程执行,基本上就是 EventLoop 线程了
pair.getFirst().process(ctx, rpcMessage);
} catch (Throwable th) {
LOGGER.error(FrameworkErrorCode.NetDispatch.getErrCode(), th.getMessage(), th);
}
}
} else {
LOGGER.error("This message type [{}] has no processor.", messageTypeAware.getTypeCode());
}
} else {
LOGGER.error("This rpcMessage body[{}] is not MessageTypeAware type.", body);
}
}

这个方法的逻辑很简单。

Seata 在 Server 启动的过程中,向 processorTable 注册了一大堆处理器,那么这里就可以根据消息类型 Code 拿到对应的处理器和线程池。

如果有线程池,就在线程池内执行处理器的方法,否则就交给 EventLoop 线程去执行。

当然,对于 Client 而言,也是这样的。

批量发送

在网络程序中,有时候也需要实现批量发送,我们来看 Seata 是怎么做的,这里主要看客户端向服务端发送。

还记得我们上面在 Client 启动的过程中提到过一个线程池 mergeSendExecutorService,如果允许批量发送,那么在 Client 启动的时候就会提交一个 MergedSendRunnable 任务,我们先来看这个任务在干啥?

private class MergedSendRunnable implements Runnable {

@Override
public void run() {
// 死循环
while (true) {
synchronized (mergeLock) {
try {
// 保证线程最多只会空闲 1ms
mergeLock.wait(MAX_MERGE_SEND_MILLS); // 1
} catch (InterruptedException ignore) {
// ignore
}
}
// 正在发送中的标识
isSending = true;
// basketMap: key 是 address,value 是发向该 address 的报文队列(阻塞队列)
basketMap.forEach((address, basket) -> {
if (basket.isEmpty()) {
return;
}
MergedWarpMessage mergeMessage = new MergedWarpMessage();
while (!basket.isEmpty()) {
// 将同一个阻塞队列中所有 RpcMessage 进行合并
RpcMessage msg = basket.poll();
mergeMessage.msgs.add((AbstractMessage) msg.getBody());
mergeMessage.msgIds.add(msg.getId());
}
if (mergeMessage.msgIds.size() > 1) {
printMergeMessageLog(mergeMessage);
}
Channel sendChannel = null;
try {
// 批量发送报文是一个同步请求,但是无需获取返回值
// 因为 messageFuture 在将报文放入 basketMap 之前就已经被创建
// 返回值将在 ClientOnResponseProcessor 中被设置
sendChannel = clientChannelManager.acquireChannel(address);
// 内部将 mergeMessage 封装为一个普通的 RpcMessage 发送
AbstractNettyRemotingClient.this.sendAsyncRequest(sendChannel, mergeMessage);
} catch (FrameworkException e) {
if (e.getErrorCode() == FrameworkErrorCode.ChannelIsNotWritable && sendChannel != null) {
destroyChannel(address, sendChannel);
}
// fast fail
for (Integer msgId : mergeMessage.msgIds) {
MessageFuture messageFuture = futures.remove(msgId);
if (messageFuture != null) {
messageFuture.setResultMessage(new RuntimeException(String.format("%s is unreachable", address), e));
}
}
LOGGER.error("client merge call failed: {}", e.getMessage(), e);
}
});
isSending = false;
}
}
}

那么,与之相关的批量发送代码如下:

public Object sendSyncRequest(Object msg) throws TimeoutException {
String serverAddress = loadBalance(getTransactionServiceGroup(), msg);
long timeoutMillis = this.getRpcRequestTimeout();
RpcMessage rpcMessage = buildRequestMessage(msg, ProtocolConstants.MSGTYPE_RESQUEST_SYNC);
// send batch message
// put message into basketMap, @see MergedSendRunnable
if (this.isEnableClientBatchSendRequest()) {
// 如果允许客户端批量消息发送
// send batch message is sync request, needs to create messageFuture and put it in futures.
MessageFuture messageFuture = new MessageFuture();
messageFuture.setRequestMessage(rpcMessage);
messageFuture.setTimeout(timeoutMillis);
futures.put(rpcMessage.getId(), messageFuture);

// put message into basketMap
// 拿到 serverAddress 对应的发送队列
BlockingQueue<RpcMessage> basket = CollectionUtils.computeIfAbsent(basketMap, serverAddress,
key -> new LinkedBlockingQueue<>());
// 将报文添加到队列中,等待 mergeSendExecutorService 进行实际的发送
if (!basket.offer(rpcMessage)) {
LOGGER.error("put message into basketMap offer failed, serverAddress: {}, rpcMessage: {}", serverAddress, rpcMessage);
return null;
}
if (!isSending) {
// 保证队列中一有数据,就唤醒线程,进行批量发送
synchronized (mergeLock) {
mergeLock.notifyAll();
}
}
try {
// 线程阻塞等待响应
return messageFuture.get(timeoutMillis, TimeUnit.MILLISECONDS);
} catch (Exception exx) {
LOGGER.error("wait response error: {}, ip: {}, request: {}", exx.getMessage(), serverAddress, rpcMessage.getBody());
if (exx instanceof TimeoutException) {
throw (TimeoutException) exx;
} else {
throw new RuntimeException(exx);
}
}
} else {
// 普通发送,拿到 channel 调父类的同步调用方法即可
Channel channel = clientChannelManager.acquireChannel(serverAddress);
return super.sendSync(channel, rpcMessage, timeoutMillis);
}
}

可以看到,这里面也用到了对象锁的同步-等待机制,那么实现的效果就是:

  1. 最多隔 1ms 会遍历 basketMap 进行报文发送。
  2. 在 mergeSendExecutorService 内部的线程阻塞期间(mainLock.wait),如果来了需要发送的报文,那么会唤醒 mainLock 上的线程,继续进行发送。

那 Server 是怎么处理的呢?主要看 MergedWarpMessage 报文的 TypeCode,实际上就是 TYPE_SEATA_MERGE,再看 Server 启动的时候对这个 Code 注册哪个处理器,实际上就是 ServerOnRequestProcessor。

这里其实就向你展示了,如何去找某个报文是怎么处理的,授人以鱼不如授人以渔!

在 ServerOnRequestProcessor 这边,实际上对应了两种处理 MergedWarpMessage 报文的方式:

  1. MergedWarpMessage 中的所有独立请求全部处理完毕之后,统一发送 MergeResultMessage。
  2. 由 batchResponseExecutorService 线程池处理发送任务,可以保证两点,一是当有报文结果就响应,即使线程 wait,也会将它 notify,二是至少 1ms 会响应一次,因为 batchResponseExecutorService 中执行的线程最多 wait 1ms。

注意,这两种方式响应的报文类型是不同的,第一种响应的是 MergeResultMessage,第二种是 BatchResultMessage,在 Client 也会有不同的处理。

ServerOnRequestProcessor 中核心处理方法如下:

private void onRequestMessage(ChannelHandlerContext ctx, RpcMessage rpcMessage) {
Object message = rpcMessage.getBody();
RpcContext rpcContext = ChannelManager.getContextFromIdentified(ctx.channel());
// the batch send request message
if (message instanceof MergedWarpMessage) {
final List<AbstractMessage> msgs = ((MergedWarpMessage) message).msgs;
final List<Integer> msgIds = ((MergedWarpMessage) message).msgIds;
// 允许 TC 服务端批量返回结果 && 客户端版本号 >= 1.5.0
if (NettyServerConfig.isEnableTcServerBatchSendResponse() && StringUtils.isNotBlank(rpcContext.getVersion())
&& Version.isAboveOrEqualVersion150(rpcContext.getVersion())) {
// 由 batchResponseExecutorService 单独处理,无需等到批量请求全部处理完毕
for (int i = 0; i < msgs.size(); i++) {
if (PARALLEL_REQUEST_HANDLE) {
int finalI = i;
CompletableFuture.runAsync(
() -> handleRequestsByMergedWarpMessageBy150(msgs.get(finalI), msgIds.get(finalI), rpcMessage, ctx, rpcContext));
} else {
handleRequestsByMergedWarpMessageBy150(msgs.get(i), msgIds.get(i), rpcMessage, ctx, rpcContext);
}
}
} else {
// 每个请求都处理完毕,才能向客户端发出响应
List<AbstractResultMessage> results = new ArrayList<>();
List<CompletableFuture<AbstractResultMessage>> futures = new ArrayList<>();
for (int i = 0; i < msgs.size(); i++) {
if (PARALLEL_REQUEST_HANDLE) {
int finalI = i;
futures.add(CompletableFuture.supplyAsync(() -> handleRequestsByMergedWarpMessage(msgs.get(finalI), rpcContext)));
} else {
results.add(i, handleRequestsByMergedWarpMessage(msgs.get(i), rpcContext));
}
}
if (CollectionUtils.isNotEmpty(futures)) {
try {
for (CompletableFuture<AbstractResultMessage> future : futures) {
results.add(future.get()); // 阻塞等待处理结果
}
} catch (InterruptedException | ExecutionException e) {
LOGGER.error("handle request error: {}", e.getMessage(), e);
}
}
MergeResultMessage resultMessage = new MergeResultMessage();
resultMessage.setMsgs(results.toArray(new AbstractResultMessage[0]));
remotingServer.sendAsyncResponse(rpcMessage, ctx.channel(), resultMessage);
}
} else {
// 处理单个报文响应
}
}

而 handleRequestsByMergedWarpMessage 和 handleRequestsByMergedWarpMessageBy150 的区别就在于后者会将结果封装为 QueueItem 加入到阻塞队列由 batchResponseExecutorService 中的线程进行实际的发送,而前者仅仅是返回处理的结果。

private AbstractResultMessage handleRequestsByMergedWarpMessage(AbstractMessage subMessage, RpcContext rpcContext) {
AbstractResultMessage resultMessage = transactionMessageHandler.onRequest(subMessage, rpcContext);
return resultMessage;
}

private void handleRequestsByMergedWarpMessageBy150(AbstractMessage msg, int msgId, RpcMessage rpcMessage,
ChannelHandlerContext ctx, RpcContext rpcContext) {
AbstractResultMessage resultMessage = transactionMessageHandler.onRequest(msg, rpcContext);
// 拿到 channel 对应的发送队列
BlockingQueue<QueueItem> msgQueue = CollectionUtils.computeIfAbsent(basketMap, ctx.channel(), key -> new LinkedBlockingQueue<>());
// 将结果添加到队列中,等待 batchResponseExecutorService 线程池实际进行发送
if (!msgQueue.offer(new QueueItem(resultMessage, msgId, rpcMessage))) {
LOGGER.error("put message into basketMap offer failed, channel: {}, rpcMessage: {}, resultMessage: {}", ctx.channel(), rpcMessage, resultMessage);
}
if (!isResponding) {
// 保证队列中一有数据,就唤醒线程,进行批量发送
synchronized (batchResponseLock) {
batchResponseLock.notifyAll();
}
}
}

再来看 batchResponseExecutorService 线程池是怎么处理批量发送的任务的?

private class BatchResponseRunnable implements Runnable {
@Override
public void run() {
while (true) {
synchronized (batchResponseLock) {
try {
// 最多空闲 1ms
batchResponseLock.wait(MAX_BATCH_RESPONSE_MILLS);
} catch (InterruptedException e) {
LOGGER.error("BatchResponseRunnable Interrupted error", e);
}
}
isResponding = true;
// 遍历 basketMap 处理
basketMap.forEach((channel, msgQueue) -> {
if (msgQueue.isEmpty()) {
return;
}
// Because the [serialization,compressor,rpcMessageId,headMap] of the response
// needs to be the same as the [serialization,compressor,rpcMessageId,headMap] of the request.
// Assemble by grouping according to the [serialization,compressor,rpcMessageId,headMap] dimensions.
// 将队列中的响应封装为 BatchResultMessage,但是注意并不是将所有的响应报文一次发送出去
// 需要按照 [serialization,compressor,rpcMessageId,headMap] 进行分组,然后按组进行异步发送
Map<ClientRequestRpcInfo, BatchResultMessage> batchResultMessageMap = new HashMap<>();
while (!msgQueue.isEmpty()) {
QueueItem item = msgQueue.poll();
BatchResultMessage batchResultMessage = CollectionUtils.computeIfAbsent(batchResultMessageMap,
new ClientRequestRpcInfo(item.getRpcMessage()),
key -> new BatchResultMessage());
batchResultMessage.getResultMessages().add(item.getResultMessage());
batchResultMessage.getMsgIds().add(item.getMsgId());
}
batchResultMessageMap.forEach((clientRequestRpcInfo, batchResultMessage) ->
remotingServer.sendAsyncResponse(buildRpcMessage(clientRequestRpcInfo), channel, batchResultMessage));
});
isResponding = false;
}
}
}

最后我们来看 Client 这边是怎么处理 Server 的批量响应报文的,根据 Client 注册的处理器,处理批量报文的处理器是 ClientOnResponseProcessor,如下:

public void process(ChannelHandlerContext ctx, RpcMessage rpcMessage) throws Exception {
// 处理 MergeResultMessage
if (rpcMessage.getBody() instanceof MergeResultMessage) {
MergeResultMessage results = (MergeResultMessage) rpcMessage.getBody();
MergedWarpMessage mergeMessage = (MergedWarpMessage) mergeMsgMap.remove(rpcMessage.getId());
for (int i = 0; i < mergeMessage.msgs.size(); i++) {
int msgId = mergeMessage.msgIds.get(i);
MessageFuture future = futures.remove(msgId);
if (future == null) {
LOGGER.error("msg: {} is not found in futures, result message: {}", msgId, results.getMsgs()[i]);
} else {
future.setResultMessage(results.getMsgs()[i]);
}
}
} else if (rpcMessage.getBody() instanceof BatchResultMessage) {
// 处理 BatchResultMessage
try {
BatchResultMessage batchResultMessage = (BatchResultMessage) rpcMessage.getBody();
for (int i = 0; i < batchResultMessage.getMsgIds().size(); i++) {
int msgId = batchResultMessage.getMsgIds().get(i);
MessageFuture future = futures.remove(msgId);
if (future == null) {
LOGGER.error("msg: {} is not found in futures, result message: {}", msgId, batchResultMessage.getResultMessages().get(i));
} else {
future.setResultMessage(batchResultMessage.getResultMessages().get(i));
}
}
} finally {
// In order to be compatible with the old version, in the batch sending of version 1.5.0,
// batch messages will also be placed in the local cache of mergeMsgMap,
// but version 1.5.0 no longer needs to obtain batch messages from mergeMsgMap
mergeMsgMap.clear();
}
} else {
// 处理非批量发送报文
MessageFuture messageFuture = futures.remove(rpcMessage.getId());
if (messageFuture != null) {
messageFuture.setResultMessage(rpcMessage.getBody());
} else {
if (rpcMessage.getBody() instanceof AbstractResultMessage) {
if (transactionMessageHandler != null) {
transactionMessageHandler.onResponse((AbstractResultMessage) rpcMessage.getBody(), null);
}
}
}
}
}

当然,这里处理的逻辑很简单,就是将结果塞到对应的 MessageFuture 中,那么最开始发送请求的、阻塞的线程就可以拿到结果了,这样一次批量发送和响应就算处理完毕了。

我们再做一些额外的思考,Seata 的批量发送为什么有两种方式,孰优孰劣?

对于 MergeResultMessage 的这种方式来说,它必须等到所有的报文都处理完毕之后才会发送出去,所以其实它的响应速度受限于处理最长时间的报文,即使其他报文在很短时间内就可以发送出去。

而 BatchResultMessage 这种方式则不然,配合 CompletableFuture 进行并行处理,它就可以实现一有报文处理完毕就发送,而不需要等其他报文的处理,它的响应速度肯定是更快的。

而后面这种方式是 Seata 1.5 版本之后才有的,其实也可以看出来这是一种更好地处理方式。

最后,再分享一张 Seata RPC 重构作者的全局事务提交请求的交互流程图:

image-20241217222048505

Seata 如何管理 Channel

在整个 TC、TM、RM 的网络通信的过程中,Channel 是一个至关重要的通信组件,而要想知道 Seata 是怎么管理 Channel 的,最容易想到的入口就是看 Server 和 Client 发送报文时是从哪里拿到到 Channel 的。

在 AbstractNettyRemotingClient 类的 sendSyncRequest 中,我们可以看到下面的代码:

public Object sendSyncRequest(Object msg) throws TimeoutException {
// ...
// Client 通过 NettyClientChannelManager 获取 Channel
Channel channel = clientChannelManager.acquireChannel(serverAddress);
return super.sendSync(channel, rpcMessage, timeoutMillis);
}

而在 AbstractNettyRemotingServer 类的 sendSyncRequest 中,我们可以看到下面的代码:

public Object sendSyncRequest(String resourceId, String clientId, Object msg, boolean tryOtherApp) throws TimeoutException {
// Server 通过 ChannelManager 拿到 Channel
Channel channel = ChannelManager.getChannel(resourceId, clientId, tryOtherApp);
if (channel == null) {
throw new RuntimeException("rm client is not connected. dbkey:" + resourceId + ",clientId:" + clientId);
}
RpcMessage rpcMessage = buildRequestMessage(msg, ProtocolConstants.MSGTYPE_RESQUEST_SYNC);
return super.sendSync(channel, rpcMessage, NettyServerConfig.getRpcRequestTimeout());
}

所以 Client 主要是通过 NettyClientChannelManager 中获取 Channel,而 Server 则是根据 resourceId 和 clientId 从 ChannelManager 中获取 Channel。

所以下面我们主要研究的就是这两个类,以及相关的一些逻辑。

Client Channel

我们先来看 Client 这边是怎么管理 Channel 的,核心类是 NettyClientChannelManager。

先简单看一下这个类的属性,

// serverAddress -> lock
private final ConcurrentMap<String, Object> channelLocks = new ConcurrentHashMap<>();
// serverAddress -> NettyPoolKey
private final ConcurrentMap<String, NettyPoolKey> poolKeyMap = new ConcurrentHashMap<>();
// serverAddress -> Channel
private final ConcurrentMap<String, Channel> channels = new ConcurrentHashMap<>();
// 对象池,NettyPoolKey -> Channel
private final GenericKeyedObjectPool<NettyPoolKey, Channel> nettyClientKeyPool;
// 函数式接口,封装了通过 serverAddress 获取 NettyPoolKey 的逻辑
private final Function<String, NettyPoolKey> poolKeyFunction;

对象池的核心类

Seata 使用了 GenericKeyedObjectPool 作为管理 Channel 的对象池。

GenericKeyedObjectPool 作为 Apache Commons Pool 库中的一个实现,它主要用于管理一组对象池,每个对象通过唯一的 Key 进行区分,可以支持多类型的对象池化需求。

在使用 GenericKeyedObjectPool 时,通常还需要配置 KeyedPoolableObjectFactory 工厂,这个工厂定义了如何创建、验证、激活、钝化以及销毁池中的对象。

当 GenericKeyedObjectPool 需要创建对象时会调用 KeyedPoolableObjectFactory 工厂的 makeObject 方法,当需要销毁时会调用 destroyObject 方法进行销毁 ……

如何池化 Channel

被池化的对象就是 Channel,而对应的 Key 是 NettyPoolKey,如下:

public class NettyPoolKey {

private TransactionRole transactionRole;
private String address;
private AbstractMessage message;

// ...
}

在 NettyPoolKey 中,维护了三个信息,事务角色(TM、RM、Server),目的 TC Server 地址,以及在 Client 连接 Server 时发送的 RPC 报文。

如何创建这个 NettyPoolKey 呢?在 Seata 中,客户端其实是有两种角色的,TM 和 RM,创建的逻辑肯定是不一样的,所以,Seata 在 AbstractNettyRemotingClient 中抽象了一个方法,它的返回值是一个函数式接口,这个函数式接口就封装了根据 serverAddress 创建 NettyPoolKey 的逻辑。

// org.apache.seata.core.rpc.netty.AbstractNettyRemotingClient#getPoolKeyFunction
protected abstract Function<String, NettyPoolKey> getPoolKeyFunction();

比如在 TM 中的实现是:

protected Function<String, NettyPoolKey> getPoolKeyFunction() {
return severAddress -> {
RegisterTMRequest message = new RegisterTMRequest(applicationId, transactionServiceGroup, getExtraData());
return new NettyPoolKey(NettyPoolKey.TransactionRole.TM_ROLE, severAddress, message);
};
}

而在 RM 中的实现是:

protected Function<String, NettyPoolKey> getPoolKeyFunction() {
return serverAddress -> {
String resourceIds = getMergedResourceKeys();
if (resourceIds != null && LOGGER.isInfoEnabled()) {
LOGGER.info("RM will register: {}", resourceIds);
}
RegisterRMRequest message = new RegisterRMRequest(applicationId, transactionServiceGroup);
message.setResourceIds(resourceIds);
return new NettyPoolKey(NettyPoolKey.TransactionRole.RM_ROLE, serverAddress, message);
};
}

从这里就可以看到,TM 在连接 Server 后发送的报文是 RegisterTMRequest,而 RM 是 RegisterRMRequest。

那这个函数式接口在什么时候被调用呢,后面再看。

我们前面也说到了,一个对象池,会配备对应的对象创建工厂 KeyedPoolableObjectFactory,在 Seata 中,以 NettyPoolableFactory 继承 KeyedPoolableObjectFactory 来实现。

/**
* Netty Channel 创建工厂,通过 NettyPoolKey 创建 Channel,该类的方法必须是线程安全的
*/
public class NettyPoolableFactory implements KeyedPoolableObjectFactory<NettyPoolKey, Channel> {

// ...

/**
* 需要一个新的实例则调用该方法
*/
@Override
public Channel makeObject(NettyPoolKey key) {
InetSocketAddress address = NetUtil.toInetSocketAddress(key.getAddress());
// 创建 Channel,本质上就是通过 bootstrap.connect 连接到 Seata Server 返回 Channel
Channel tmpChannel = clientBootstrap.getNewChannel(address);
long start = System.currentTimeMillis();
Object response;
Channel channelToServer = null;
if (key.getMessage() == null) {
throw new FrameworkException("register msg is null, role:" + key.getTransactionRole().name());
}
try {
// 发送 Message,TM 就是 RegisterTMRequest,RM 就是 RegisterRMRequest
response = rpcRemotingClient.sendSyncRequest(tmpChannel, key.getMessage());
// 根据 response 判断是否注册成功
if (!isRegisterSuccess(response, key.getTransactionRole())) {
rpcRemotingClient.onRegisterMsgFail(key.getAddress(), tmpChannel, response, key.getMessage());
} else {
// 注册成功
channelToServer = tmpChannel;
// 将 serverAddress 作为 key,Channel 作为 value,添加到 NettyClientChannelManager.channels 中
// 如果是 RM 可能还需要将 Server 注册 resources
rpcRemotingClient.onRegisterMsgSuccess(key.getAddress(), tmpChannel, response, key.getMessage());
}
} catch (Exception exx) {
if (tmpChannel != null) {
tmpChannel.close();
}
throw new FrameworkException("register " + key.getTransactionRole().name() + " error, errMsg:" + exx.getMessage());
}
return channelToServer;
}

// ...

@Override
public void destroyObject(NettyPoolKey key, Channel channel) throws Exception {
if (channel != null) {
channel.disconnect();
channel.close();
}
}

/**
* 需要借用对象时会调用该方法校验对象有效性(可选)
*/
@Override
public boolean validateObject(NettyPoolKey key, Channel obj) {
if (obj != null && obj.isActive()) {
return true;
}
return false;
}

/**
* 需要借用对象时会调用该方法激活对象
*/
@Override
public void activateObject(NettyPoolKey key, Channel obj) throws Exception {}

/**
* 归还对象时会调用该方法钝化对象
*/
@Override
public void passivateObject(NettyPoolKey key, Channel obj) throws Exception {}
}

获取 Channel

在整个 Seata 客户端,有三个口径可以获取 Channel,即初始化、定时重连,发送报文时获取 Channel。

// 口径一
private void initConnection() {
boolean failFast =
ConfigurationFactory.getInstance().getBoolean(ConfigurationKeys.ENABLE_TM_CLIENT_CHANNEL_CHECK_FAIL_FAST, DefaultValues.DEFAULT_CLIENT_CHANNEL_CHECK_FAIL_FAST);
getClientChannelManager().initReconnect(transactionServiceGroup, failFast);
}

// 口径二
public void init() {
// 默认延时 60s 定时 10s 周期重连
timerExecutor.scheduleAtFixedRate(() -> {
try {
clientChannelManager.reconnect(getTransactionServiceGroup());
} catch (Exception ex) {
LOGGER.warn("reconnect server failed. {}", ex.getMessage());
}
}, SCHEDULE_DELAY_MILLS, SCHEDULE_INTERVAL_MILLS, TimeUnit.MILLISECONDS);
// ...
}

// 口径三
public Object sendSyncRequest(Object msg) throws TimeoutException {
// ...
// Client 通过 NettyClientChannelManager 获取 Channel
Channel channel = clientChannelManager.acquireChannel(serverAddress);
return super.sendSync(channel, rpcMessage, timeoutMillis);
}

不过,这三个口径最后都会调用到 clientChannelManager 的 acquireChannel 方法获取 Channel。

/**
* 根据 serverAddress 拿到 Channel,如果 Channel 不存在或者连接已死则需要重新建立连接
*/
Channel acquireChannel(String serverAddress) {
// 从 channels 中根据 serverAddress 拿到 Channel
Channel channelToServer = channels.get(serverAddress);
if (channelToServer != null) {
channelToServer = getExistAliveChannel(channelToServer, serverAddress);
if (channelToServer != null) {
return channelToServer;
}
}
// 如果 channels 没有这个 Channel 或者这个 Channel 已死,则需要对这个地址建立连接
Object lockObj = CollectionUtils.computeIfAbsent(channelLocks, serverAddress, key -> new Object());
synchronized (lockObj) {
// 建立连接
return doConnect(serverAddress);
}
}

private Channel doConnect(String serverAddress) {
// 再尝试拿一次
Channel channelToServer = channels.get(serverAddress);
if (channelToServer != null && channelToServer.isActive()) {
return channelToServer;
}
Channel channelFromPool;
try {
// 这里就调用了函数式接口
NettyPoolKey currentPoolKey = poolKeyFunction.apply(serverAddress);
poolKeyMap.put(serverAddress, currentPoolKey);
// 从对象池中 borrowObject,如果需要创建对象,则会调用工厂的 makeObject 方法,
// 该方法内部就会向 Server 进行 connect,并且发送 currentPoolKey.message 的报文
channelFromPool = nettyClientKeyPool.borrowObject(currentPoolKey);
channels.put(serverAddress, channelFromPool);
} catch (Exception exx) {
LOGGER.error("{} register RM failed.", FrameworkErrorCode.RegisterRM.getErrCode(), exx);
throw new FrameworkException("can not register RM,err:" + exx.getMessage());
}
return channelFromPool;
}

Server Channel

而在 Server 这边,基本上有关 Channe 管理的核心逻辑都在 ChannelManager 中,那 Server 这边的 Channel 是怎么来的呢?还记得在 Client 那边向 Server 发起连接,成功之后还会发送 TM 和 RM 的一个注册请求。

这里先来看看 Server 是怎么处理这些 registerRequest 的。

处理 Client 注册

与之相关的处理器是 RegRmProcessor 和 RegTmProcessor,在这两个处理器中,最核心的逻辑就是调用 ChannelManager 的 registerTMChannel 和 registerRMChannel 方法。

public static void registerTMChannel(RegisterTMRequest request, Channel channel) throws IncompatibleVersionException {
// 构建 RpcContext,这个 RpcContext 就是维护了客户端连接信息上下文
RpcContext rpcContext = buildChannelHolder(NettyPoolKey.TransactionRole.TM_ROLE, request.getVersion(),
request.getApplicationId(),
request.getTransactionServiceGroup(),
null, channel);
// 将 Channel 作为 key,rpcContext 作为 value,put 到 IDENTIFIED_CHANNELS 中
rpcContext.holdInIdentifiedChannels(IDENTIFIED_CHANNELS);
// applicationId:clientIp
String clientIdentified = rpcContext.getApplicationId() + Constants.CLIENT_ID_SPLIT_CHAR + ChannelUtil.getClientIpFromChannel(channel);
// 将 Channel 信息存储到 TM_CHANNELS 中
ConcurrentMap<Integer, RpcContext> clientIdentifiedMap = CollectionUtils.computeIfAbsent(TM_CHANNELS, clientIdentified, key -> new ConcurrentHashMap<>());
rpcContext.holdInClientChannels(clientIdentifiedMap);
}

public static void registerRMChannel(RegisterRMRequest resourceManagerRequest, Channel channel) throws IncompatibleVersionException {
Set<String> dbkeySet = dbKeytoSet(resourceManagerRequest.getResourceIds());
RpcContext rpcContext;
if (!IDENTIFIED_CHANNELS.containsKey(channel)) {
// 构建 RpcContext 和 IDENTIFIED_CHANNELS
rpcContext = buildChannelHolder(NettyPoolKey.TransactionRole.RM_ROLE, resourceManagerRequest.getVersion(),
resourceManagerRequest.getApplicationId(), resourceManagerRequest.getTransactionServiceGroup(),
resourceManagerRequest.getResourceIds(), channel);
rpcContext.holdInIdentifiedChannels(IDENTIFIED_CHANNELS);
} else {
rpcContext = IDENTIFIED_CHANNELS.get(channel);
rpcContext.addResources(dbkeySet);
}
if (dbkeySet == null || dbkeySet.isEmpty()) {
return;
}
for (String resourceId : dbkeySet) {
String clientIp;
// 维护 RM_CHANNELS 信息
ConcurrentMap<Integer, RpcContext> portMap = CollectionUtils.computeIfAbsent(RM_CHANNELS, resourceId, key -> new ConcurrentHashMap<>())
.computeIfAbsent(resourceManagerRequest.getApplicationId(), key -> new ConcurrentHashMap<>())
.computeIfAbsent(clientIp = ChannelUtil.getClientIpFromChannel(channel), key -> new ConcurrentHashMap<>());
rpcContext.holdInResourceManagerChannels(resourceId, portMap);
updateChannelsResource(resourceId, clientIp, resourceManagerRequest.getApplicationId());
}
}

这两个方法逻辑很简单,就是基于注册请求和 Channel 的信息构建 RpcContext,维护 Server 内的相关 Map 集合,IDENTIFIED_CHANNELS、RM_CHANNELS、TM_CHANNELS。

但是,说实话,这几个集合实在是嵌套的有点深,不知道能不能优化一下。

/**
* Channel -> RpcContext
*/
private static final ConcurrentMap<Channel, RpcContext> IDENTIFIED_CHANNELS = new ConcurrentHashMap<>();

/**
* resourceId -> applicationId -> ip -> port -> RpcContext
*/
// resourceId applicationId ip
private static final ConcurrentMap<String, ConcurrentMap<String, ConcurrentMap<String,
// port RpcContext
ConcurrentMap<Integer, RpcContext>>>> RM_CHANNELS = new ConcurrentHashMap<>();

/**
* applicationId:clientIp -> port -> RpcContext
*/
private static final ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>> TM_CHANNELS = new ConcurrentHashMap<>();

获取 Channel

在 Server 这边,获取 Channel 的逻辑,真的是超长,感兴趣自己看看吧,本质上就是从 map 中拿到一个有效的 Channel。

public static Channel getChannel(String resourceId, String clientId, boolean tryOtherApp) {
Channel resultChannel = null;
// 解析 ClientId,三部分组成:applicationId + clientIp + clientPort
String[] clientIdInfo = parseClientId(clientId);
if (clientIdInfo == null || clientIdInfo.length != 3) {
throw new FrameworkException("Invalid Client ID: " + clientId);
}
if (StringUtils.isBlank(resourceId)) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("No channel is available, resourceId is null or empty");
}
return null;
}
// applicationId
String targetApplicationId = clientIdInfo[0];
// clientIp
String targetIP = clientIdInfo[1];
// clientPort
int targetPort = Integer.parseInt(clientIdInfo[2]);
// 下面就是不断取出内层的 ConcurrentHashMap
ConcurrentMap<String, ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>>> applicationIdMap = RM_CHANNELS.get(resourceId);
if (targetApplicationId == null || applicationIdMap == null || applicationIdMap.isEmpty()) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("No channel is available for resource[{}]", resourceId);
}
return null;
}
ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>> ipMap = applicationIdMap.get(targetApplicationId);
if (ipMap != null && !ipMap.isEmpty()) {
// Firstly, try to find the original channel through which the branch was registered.
// 端口 -> RpcContext
ConcurrentMap<Integer, RpcContext> portMapOnTargetIP = ipMap.get(targetIP);
/**
* 在 targetIp 上拿 Channel
*/
if (portMapOnTargetIP != null && !portMapOnTargetIP.isEmpty()) {
RpcContext exactRpcContext = portMapOnTargetIP.get(targetPort);
if (exactRpcContext != null) {
Channel channel = exactRpcContext.getChannel();
if (channel.isActive()) {
// Channel 有效,则跳过下面所有的 if 返回这个 Channel
resultChannel = channel;
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Just got exactly the one {} for {}", channel, clientId);
}
} else {
if (portMapOnTargetIP.remove(targetPort, exactRpcContext)) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Removed inactive {}", channel);
}
}
}
}
// The original channel was broken, try another one.
if (resultChannel == null) {
// 尝试当前节点上的其他端口
for (ConcurrentMap.Entry<Integer, RpcContext> portMapOnTargetIPEntry : portMapOnTargetIP.entrySet()) {
Channel channel = portMapOnTargetIPEntry.getValue().getChannel();
if (channel.isActive()) {
resultChannel = channel;
if (LOGGER.isInfoEnabled()) {
LOGGER.info(
"Choose {} on the same IP[{}] as alternative of {}", channel, targetIP, clientId);
}
break;
} else {
if (portMapOnTargetIP.remove(portMapOnTargetIPEntry.getKey(),
portMapOnTargetIPEntry.getValue())) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Removed inactive {}", channel);
}
}
}
}
}
}
/**
* 在 targetApplicationId 上拿 Channel
*/
// No channel on the app node, try another one.
if (resultChannel == null) {
for (ConcurrentMap.Entry<String, ConcurrentMap<Integer, RpcContext>> ipMapEntry : ipMap.entrySet()) {
if (ipMapEntry.getKey().equals(targetIP)) {
continue;
}
ConcurrentMap<Integer, RpcContext> portMapOnOtherIP = ipMapEntry.getValue();
if (portMapOnOtherIP == null || portMapOnOtherIP.isEmpty()) {
continue;
}
for (ConcurrentMap.Entry<Integer, RpcContext> portMapOnOtherIPEntry : portMapOnOtherIP.entrySet()) {
Channel channel = portMapOnOtherIPEntry.getValue().getChannel();
if (channel.isActive()) {
resultChannel = channel;
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Choose {} on the same application[{}] as alternative of {}", channel, targetApplicationId, clientId);
}
break;
} else {
if (portMapOnOtherIP.remove(portMapOnOtherIPEntry.getKey(), portMapOnOtherIPEntry.getValue())) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Removed inactive {}", channel);
}
}
}
}
if (resultChannel != null) {
break;
}
}
}
}
if (resultChannel == null && tryOtherApp) {
// 尝试其他 applicationId
resultChannel = tryOtherApp(applicationIdMap, targetApplicationId);
if (resultChannel == null) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("No channel is available for resource[{}] as alternative of {}", resourceId, clientId);
}
} else {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Choose {} on the same resource[{}] as alternative of {}", resultChannel, resourceId, clientId);
}
}
}
return resultChannel;
}

private static Channel tryOtherApp(ConcurrentMap<String, ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>>> applicationIdMap, String myApplicationId) {
Channel chosenChannel = null;
for (ConcurrentMap.Entry<String, ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>>> applicationIdMapEntry : applicationIdMap.entrySet()) {
if (!StringUtils.isNullOrEmpty(myApplicationId) && applicationIdMapEntry.getKey().equals(myApplicationId)) {
continue;
}
ConcurrentMap<String, ConcurrentMap<Integer, RpcContext>> targetIPMap = applicationIdMapEntry.getValue();
if (targetIPMap == null || targetIPMap.isEmpty()) {
continue;
}
for (ConcurrentMap.Entry<String, ConcurrentMap<Integer, RpcContext>> targetIPMapEntry : targetIPMap.entrySet()) {
ConcurrentMap<Integer, RpcContext> portMap = targetIPMapEntry.getValue();
if (portMap == null || portMap.isEmpty()) {
continue;
}
for (ConcurrentMap.Entry<Integer, RpcContext> portMapEntry : portMap.entrySet()) {
Channel channel = portMapEntry.getValue().getChannel();
if (channel.isActive()) {
chosenChannel = channel;
break;
} else {
if (portMap.remove(portMapEntry.getKey(), portMapEntry.getValue())) {
if (LOGGER.isInfoEnabled()) {
LOGGER.info("Removed inactive {}", channel);
}
}
}
}
if (chosenChannel != null) {
break;
}
}
if (chosenChannel != null) {
break;
}
}
return chosenChannel;
}

一图总结

最后,再以一个时序图来总结一下 Channel 的管理过程。

image-20241217222155609

Seata 如何设计协议

对于一个网络程序而言,通信协议是必不可少的,Seata 也不例外,这里我们就看看 Seata V1 版本的协议是如何实现的。

与之相关类主要有 ProtocolEncoderV1、ProtocolDecoderV1。

当然,我们前面也知道 Seata Server 启动时加入的处理器其实是 MultiProtocolDecoder,在这个类的 decode 方法中,如下:

protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
ByteBuf frame;
Object decoded;
byte version;
try {
if (isV0(in)) {
decoded = in;
version = ProtocolConstants.VERSION_0;
} else {
decoded = super.decode(ctx, in);
version = decideVersion(decoded);
}
if (decoded instanceof ByteBuf) {
frame = (ByteBuf) decoded;
// 通过 MultiProtocolDecoder 进行多版本协议识别
// 通过 version 选择对应的编解码器
ProtocolDecoder decoder = protocolDecoderMap.get(version);
ProtocolEncoder encoder = protocolEncoderMap.get(version);
try {
if (decoder == null || encoder == null) {
throw new UnsupportedOperationException("Unsupported version: " + version);
}
return decoder.decodeFrame(frame);
} finally {
if (version != ProtocolConstants.VERSION_0) {
frame.release();
}
// 将选定的编解码器加入到 pipeline,并且移除 MultiProtocolDecoder
ctx.pipeline().addLast((ChannelHandler) decoder);
ctx.pipeline().addLast((ChannelHandler) encoder);
if (channelHandlers != null) {
ctx.pipeline().addLast(channelHandlers);
}
ctx.pipeline().remove(this);
}
}
} catch (Exception exx) {
LOGGER.error("Decode frame error, cause: {}", exx.getMessage());
throw new DecodeException(exx);
}
return decoded;
}

所以,这里选择好与 version 对应的编解码器,然后加入到 pipeline 中,就会将 MultiProtocolDecoder 移除。

V1 版本协议

Seata 的协议设计是比较周全并且通用的,也是主流的解决粘包半包问题的解决方案,即消息长度 + 消息内容。

协议的格式如下:

image-20241217222155609

可以看到,包括魔数、协议版本号、长度域、头长度、报文类型、序列化算法、压缩算法、请求 id、可选的 map 扩展以及报文体。

如何进行编解码

Seata 解码器使用了 Netty 内置的 LengthFieldBasedFrameDecoder,不熟悉的可以看看。

不过编解码并不难,所以简单给出代码,不过多解释。

package org.apache.seata.core.rpc.netty.v1;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import org.apache.seata.core.rpc.netty.ProtocolEncoder;
import org.apache.seata.core.serializer.Serializer;
import org.apache.seata.core.compressor.Compressor;
import org.apache.seata.core.compressor.CompressorFactory;
import org.apache.seata.core.protocol.ProtocolConstants;
import org.apache.seata.core.protocol.RpcMessage;
import org.apache.seata.core.serializer.SerializerServiceLoader;
import org.apache.seata.core.serializer.SerializerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

/**
* <pre>
* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
* | magic |proto| full length | head | Msg |Seria|Compr| RequestId |
* | code |versi| (head+body) | length |Type |lizer|ess | |
* +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
* | Head Map [Optional] |
* +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
* | body |
* +-----------------------------------------------------------------------------------------------+
* </pre>
* <p>
* <li>Full Length: include all data </li>
* <li>Head Length: include head data from magic code to head map. </li>
* <li>Body Length: Full Length - Head Length</li>
* </p>
*/
public class ProtocolEncoderV1 extends MessageToByteEncoder implements ProtocolEncoder {

private static final Logger LOGGER = LoggerFactory.getLogger(ProtocolEncoderV1.class);

public void encode(RpcMessage message, ByteBuf out) {
try {
ProtocolRpcMessageV1 rpcMessage = new ProtocolRpcMessageV1();
rpcMessage.rpcMsgToProtocolMsg(message);
int fullLength = ProtocolConstants.V1_HEAD_LENGTH;
int headLength = ProtocolConstants.V1_HEAD_LENGTH;
byte messageType = rpcMessage.getMessageType();
out.writeBytes(ProtocolConstants.MAGIC_CODE_BYTES);
out.writeByte(ProtocolConstants.VERSION_1);
// full Length(4B) and head length(2B) will fix in the end.
out.writerIndex(out.writerIndex() + 6); // 这里跳过 full length 和 head length 的位置,最后在补
out.writeByte(messageType);
out.writeByte(rpcMessage.getCodec());
out.writeByte(rpcMessage.getCompressor());
out.writeInt(rpcMessage.getId());
// direct write head with zero-copy
Map<String, String> headMap = rpcMessage.getHeadMap();
if (headMap != null && !headMap.isEmpty()) {
int headMapBytesLength = HeadMapSerializer.getInstance().encode(headMap, out);
headLength += headMapBytesLength;
fullLength += headMapBytesLength;
}
byte[] bodyBytes = null;
// heartbeat don't have body
if (messageType != ProtocolConstants.MSGTYPE_HEARTBEAT_REQUEST && messageType != ProtocolConstants.MSGTYPE_HEARTBEAT_RESPONSE) {
Serializer serializer = SerializerServiceLoader.load(SerializerType.getByCode(rpcMessage.getCodec()), ProtocolConstants.VERSION_1);
bodyBytes = serializer.serialize(rpcMessage.getBody());
Compressor compressor = CompressorFactory.getCompressor(rpcMessage.getCompressor());
bodyBytes = compressor.compress(bodyBytes);
fullLength += bodyBytes.length;
}
if (bodyBytes != null) {
out.writeBytes(bodyBytes);
}
// fix fullLength and headLength
int writeIndex = out.writerIndex();
// skip magic code(2B) + version(1B)
out.writerIndex(writeIndex - fullLength + 3);
out.writeInt(fullLength);
out.writeShort(headLength);
out.writerIndex(writeIndex);
} catch (Throwable e) {
LOGGER.error("Encode request error!", e);
throw e;
}
}

@Override
protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws Exception {
try {
if (msg instanceof RpcMessage) {
this.encode((RpcMessage) msg, out);
} else {
throw new UnsupportedOperationException("Not support this class:" + msg.getClass());
}
} catch (Throwable e) {
LOGGER.error("Encode request error!", e);
}
}
}
package org.apache.seata.core.rpc.netty.v1;

import java.util.List;
import java.util.Map;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import org.apache.seata.core.compressor.Compressor;
import org.apache.seata.core.compressor.CompressorFactory;
import org.apache.seata.core.exception.DecodeException;
import org.apache.seata.core.protocol.HeartbeatMessage;
import org.apache.seata.core.protocol.ProtocolConstants;
import org.apache.seata.core.protocol.RpcMessage;
import org.apache.seata.core.rpc.netty.ProtocolDecoder;
import org.apache.seata.core.serializer.Serializer;
import org.apache.seata.core.serializer.SerializerServiceLoader;
import org.apache.seata.core.serializer.SerializerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* <pre>
* 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
* +-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
* | magic |proto| full length | head | Msg |Seria|Compr| RequestId |
* | code |versi| (head+body) | length |Type |lizer|ess | |
* +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
* | Head Map [Optional] |
* +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
* | body |
* +-----------------------------------------------------------------------------------------------+
* </pre>
* <p>
* <li>Full Length: include all data </li>
* <li>Head Length: include head data from magic code to head map. </li>
* <li>Body Length: Full Length - Head Length</li>
* </p>
*/
public class ProtocolDecoderV1 extends LengthFieldBasedFrameDecoder implements ProtocolDecoder {

private static final Logger LOGGER = LoggerFactory.getLogger(ProtocolDecoderV1.class);

private final List<SerializerType> supportDeSerializerTypes;

public ProtocolDecoderV1() {
/**
* int maxFrameLength,
* int lengthFieldOffset, 魔术 2B、版本号 1B 所以长度偏移 3B
* int lengthFieldLength, FullLength is int(4B). so values is 4
* int lengthAdjustment, FullLength include all data and read 7 bytes before, so the left length is (FullLength-7). so values is -7
* int initialBytesToStrip we will check magic code and version self, so do not strip any bytes. so values is 0
*/
super(ProtocolConstants.MAX_FRAME_LENGTH, 3, 4, -7, 0);
supportDeSerializerTypes = SerializerServiceLoader.getSupportedSerializers();
if (supportDeSerializerTypes.isEmpty()) {
throw new IllegalArgumentException("No serializer found");
}
}

@Override
public RpcMessage decodeFrame(ByteBuf frame) {
byte b0 = frame.readByte();
byte b1 = frame.readByte();
if (ProtocolConstants.MAGIC_CODE_BYTES[0] != b0 || ProtocolConstants.MAGIC_CODE_BYTES[1] != b1) {
throw new IllegalArgumentException("Unknown magic code: " + b0 + ", " + b1);
}
byte version = frame.readByte();
int fullLength = frame.readInt();
short headLength = frame.readShort();
byte messageType = frame.readByte();
byte codecType = frame.readByte();
byte compressorType = frame.readByte();
int requestId = frame.readInt();
ProtocolRpcMessageV1 rpcMessage = new ProtocolRpcMessageV1();
rpcMessage.setCodec(codecType);
rpcMessage.setId(requestId);
rpcMessage.setCompressor(compressorType);
rpcMessage.setMessageType(messageType);
// direct read head with zero-copy
int headMapLength = headLength - ProtocolConstants.V1_HEAD_LENGTH;
if (headMapLength > 0) {
Map<String, String> map = HeadMapSerializer.getInstance().decode(frame, headMapLength);
rpcMessage.getHeadMap().putAll(map);
}
// read body
if (messageType == ProtocolConstants.MSGTYPE_HEARTBEAT_REQUEST) {
rpcMessage.setBody(HeartbeatMessage.PING);
} else if (messageType == ProtocolConstants.MSGTYPE_HEARTBEAT_RESPONSE) {
rpcMessage.setBody(HeartbeatMessage.PONG);
} else {
int bodyLength = fullLength - headLength;
if (bodyLength > 0) {
byte[] bs = new byte[bodyLength];
frame.readBytes(bs);
Compressor compressor = CompressorFactory.getCompressor(compressorType);
bs = compressor.decompress(bs);
SerializerType protocolType = SerializerType.getByCode(rpcMessage.getCodec());
if (this.supportDeSerializerTypes.contains(protocolType)) {
Serializer serializer = SerializerServiceLoader.load(protocolType, ProtocolConstants.VERSION_1);
rpcMessage.setBody(serializer.deserialize(bs));
} else {
throw new IllegalArgumentException("SerializerType not match");
}
}
}
return rpcMessage.protocolMsgToRpcMsg();
}

@Override
protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
Object decoded;
try {
decoded = super.decode(ctx, in);
if (decoded instanceof ByteBuf) {
ByteBuf frame = (ByteBuf) decoded;
try {
return decodeFrame(frame);
} finally {
frame.release();
}
}
} catch (Exception exx) {
LOGGER.error("Decode frame error, cause: {}", exx.getMessage());
throw new DecodeException(exx);
}
return decoded;
}
}

总结

就目前看来,Seata 的网络通信实现的是比较容易看懂的,不过,这篇文章的分析也仅仅只是浮于表面,对深层次的更加重要的代码健壮性、异常处理、优雅关闭等问题都没有聊到,看后面有新的理解再分析分析。

原文链接