diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java index b14b7fa77c31f..84674a95db220 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/DbzCdcEngineRunner.java @@ -16,7 +16,7 @@ import com.risingwave.connector.source.common.DbzConnectorConfig; import com.risingwave.connector.source.common.DbzSourceUtils; -import com.risingwave.java.binding.Binding; +import com.risingwave.java.binding.CdcSourceChannel; import com.risingwave.proto.ConnectorServiceProto.GetEventStreamResponse; import io.debezium.config.CommonConnectorConfig; import io.grpc.stub.StreamObserver; @@ -69,7 +69,7 @@ public static DbzCdcEngineRunner newCdcEngineRunner( return runner; } - public static DbzCdcEngineRunner create(DbzConnectorConfig config, long channelPtr) { + public static DbzCdcEngineRunner create(DbzConnectorConfig config, CdcSourceChannel channel) { DbzCdcEngineRunner runner = new DbzCdcEngineRunner(config); try { var sourceId = config.getSourceId(); @@ -90,8 +90,7 @@ public static DbzCdcEngineRunner create(DbzConnectorConfig config, long channelP (error != null && error.getMessage() != null ? error.getMessage() : message); - if (!Binding.sendCdcSourceErrorToChannel( - channelPtr, errorMsg)) { + if (!channel.sendError(errorMsg)) { LOG.warn( "engine#{} unable to send error message: {}", sourceId, diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java index 7a8c513d963d9..dbc5ec6ca98c2 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/source/core/JniDbzSourceHandler.java @@ -22,7 +22,7 @@ import com.risingwave.connector.source.common.CdcConnectorException; import com.risingwave.connector.source.common.DbzConnectorConfig; import com.risingwave.connector.source.common.DbzSourceUtils; -import com.risingwave.java.binding.Binding; +import com.risingwave.java.binding.CdcSourceChannel; import com.risingwave.metrics.ConnectorNodeMetrics; import com.risingwave.proto.ConnectorServiceProto; import com.risingwave.proto.ConnectorServiceProto.GetEventStreamResponse; @@ -40,10 +40,12 @@ public class JniDbzSourceHandler { private final DbzConnectorConfig config; private final DbzCdcEngineRunner runner; + private final CdcSourceChannel channel; - public JniDbzSourceHandler(DbzConnectorConfig config, long channelPtr) { + public JniDbzSourceHandler(DbzConnectorConfig config, CdcSourceChannel channel) { this.config = config; - this.runner = DbzCdcEngineRunner.create(config, channelPtr); + this.runner = DbzCdcEngineRunner.create(config, channel); + this.channel = channel; if (runner == null) { throw new CdcConnectorException("Failed to create engine runner"); @@ -52,6 +54,9 @@ public JniDbzSourceHandler(DbzConnectorConfig config, long channelPtr) { public static void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr) throws Exception { + + var channel = CdcSourceChannel.fromOwnedPointer(channelPtr); + var request = ConnectorServiceProto.GetEventStreamRequest.parseFrom(getEventStreamRequestBytes); // userProps extracted from request, underlying implementation is UnmodifiableMap @@ -72,10 +77,8 @@ public static void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long mutableUserProps, request.getSnapshotDone(), isCdcSourceJob); - JniDbzSourceHandler handler = new JniDbzSourceHandler(config, channelPtr); - // register handler to the registry - JniDbzSourceRegistry.register(config.getSourceId(), handler); - handler.start(channelPtr); + JniDbzSourceHandler handler = new JniDbzSourceHandler(config, channel); + handler.start(); } public void commitOffset(String encodedOffset) throws InterruptedException { @@ -96,12 +99,14 @@ public void commitOffset(String encodedOffset) throws InterruptedException { } } - public void start(long channelPtr) { - + public void start() { try { + // register handler to the registry + JniDbzSourceRegistry.register(config.getSourceId(), this); + // Start the engine var startOk = runner.start(); - if (!sendHandshakeMessage(runner, channelPtr, startOk)) { + if (!sendHandshakeMessage(runner, channel, startOk)) { LOG.error( "Failed to send handshake message to channel. sourceId={}", config.getSourceId()); @@ -125,14 +130,16 @@ public void start(long channelPtr) { "Engine#{}: emit one chunk {} events to network ", config.getSourceId(), resp.getEventsCount()); - success = Binding.sendCdcSourceMsgToChannel(channelPtr, resp.toByteArray()); + success = channel.send(resp.toByteArray()); } else { // If resp is null means just check whether channel is closed. - success = Binding.sendCdcSourceMsgToChannel(channelPtr, null); + success = channel.send(null); } + // When user drops the connector, the channel rx will be dropped and we fail to send + // the message. We should stop the engine in this case. if (!success) { LOG.info( - "Engine#{}: JNI sender broken detected, stop the engine", + "Engine#{}: JNI receiver closed, stop the engine", config.getSourceId()); runner.stop(); return; @@ -145,14 +152,14 @@ public void start(long channelPtr) { } catch (Exception e) { LOG.warn("Failed to stop Engine#{}", config.getSourceId(), e); } + } finally { + // remove the handler from registry + JniDbzSourceRegistry.unregister(config.getSourceId()); } - - // remove the handler from registry - JniDbzSourceRegistry.unregister(config.getSourceId()); } private boolean sendHandshakeMessage( - DbzCdcEngineRunner runner, long channelPtr, boolean startOk) throws Exception { + DbzCdcEngineRunner runner, CdcSourceChannel channel, boolean startOk) throws Exception { // send a handshake message to notify the Source executor // if the handshake is not ok, the split reader will return error to source actor var controlInfo = @@ -163,7 +170,7 @@ private boolean sendHandshakeMessage( .setSourceId(config.getSourceId()) .setControl(controlInfo) .build(); - var success = Binding.sendCdcSourceMsgToChannel(channelPtr, handshakeMsg.toByteArray()); + var success = channel.send(handshakeMsg.toByteArray()); if (!success) { LOG.info( "Engine#{}: JNI sender broken detected, stop the engine", config.getSourceId()); diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index 42ddcf2b28d4d..ab9ee03e2f445 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -94,6 +94,8 @@ public static native void tracingSlf4jEvent( public static native boolean sendCdcSourceErrorToChannel(long channelPtr, String errorMsg); + public static native void cdcSourceSenderClose(long channelPtr); + public static native com.risingwave.java.binding.JniSinkWriterStreamRequest recvSinkWriterRequestFromChannel(long channelPtr); diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/CdcSourceChannel.java b/java/java-binding/src/main/java/com/risingwave/java/binding/CdcSourceChannel.java new file mode 100644 index 0000000000000..83cebfc9b5b1b --- /dev/null +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/CdcSourceChannel.java @@ -0,0 +1,44 @@ +// Copyright 2025 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package com.risingwave.java.binding; + +import java.lang.ref.Cleaner; + +public class CdcSourceChannel { + private final long pointer; + + private static final Cleaner cleaner = Cleaner.create(); + + CdcSourceChannel(long pointer) { + this.pointer = pointer; + cleaner.register( + this, + () -> { + Binding.cdcSourceSenderClose(pointer); + }); + } + + public static CdcSourceChannel fromOwnedPointer(long pointer) { + return new CdcSourceChannel(pointer); + } + + public boolean send(byte[] msg) { + return Binding.sendCdcSourceMsgToChannel(pointer, msg); + } + + public boolean sendError(String errorMsg) { + return Binding.sendCdcSourceErrorToChannel(pointer, errorMsg); + } +} diff --git a/src/connector/src/source/cdc/source/reader.rs b/src/connector/src/source/cdc/source/reader.rs index 46f1c79ceedb6..899a7d3a55abd 100644 --- a/src/connector/src/source/cdc/source/reader.rs +++ b/src/connector/src/source/cdc/source/reader.rs @@ -23,7 +23,7 @@ use risingwave_common::bail; use risingwave_common::metrics::GLOBAL_ERROR_METRICS; use risingwave_common::util::addr::HostAddr; use risingwave_jni_core::jvm_runtime::{execute_with_jni_env, JVM}; -use risingwave_jni_core::{call_static_method, JniReceiverType, JniSenderType}; +use risingwave_jni_core::{call_static_method, JniReceiverType, OwnedPointer}; use risingwave_pb::connector_service::{GetEventStreamRequest, GetEventStreamResponse}; use thiserror_ext::AsReport; use tokio::sync::mpsc; @@ -98,7 +98,7 @@ impl<T: CdcSourceTypeTrait> SplitReader for CdcSplitReader<T> { let source_id = split.split_id() as u64; let source_type = conn_props.get_source_type_pb(); - let (mut tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); + let (tx, mut rx) = mpsc::channel(DEFAULT_CHANNEL_SIZE); let jvm = JVM.get_or_init()?; let get_event_stream_request = GetEventStreamRequest { @@ -129,12 +129,16 @@ impl<T: CdcSourceTypeTrait> SplitReader for CdcSplitReader<T> { } }; + // `runJniDbzSourceThread` will take ownership of `tx`, and release it later in + // `Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose` via ref cleaner. + let tx: OwnedPointer<_> = tx.into(); + let result = call_static_method!( env, {com.risingwave.connector.source.core.JniDbzSourceHandler}, {void runJniDbzSourceThread(byte[] getEventStreamRequestBytes, long channelPtr)}, &get_event_stream_request_bytes, - &mut tx as *mut JniSenderType<GetEventStreamResponse> + tx.into_pointer() ); match result { diff --git a/src/jni_core/src/lib.rs b/src/jni_core/src/lib.rs index 5cfd8c66478c2..b2e0d58a2686e 100644 --- a/src/jni_core/src/lib.rs +++ b/src/jni_core/src/lib.rs @@ -152,21 +152,36 @@ impl<T> From<T> for Pointer<'static, T> { impl<'a, T> Pointer<'a, T> { fn as_ref(&self) -> &'a T { - debug_assert!(self.pointer != 0); + assert!(self.pointer != 0); unsafe { &*(self.pointer as *const T) } } fn as_mut(&mut self) -> &'a mut T { - debug_assert!(self.pointer != 0); + assert!(self.pointer != 0); unsafe { &mut *(self.pointer as *mut T) } } } +/// A pointer that owns the object it points to. +/// +/// Note that dropping an `OwnedPointer` does not release the object. +/// Instead, you should call [`OwnedPointer::release`] manually. pub type OwnedPointer<T> = Pointer<'static, T>; impl<T> OwnedPointer<T> { - fn drop(self) { - debug_assert!(self.pointer != 0); + /// Consume `self` and return the pointer value. Used for passing to JNI. + pub fn into_pointer(self) -> jlong { + self.pointer + } + + /// Release the object behind the pointer. + fn release(self) { + tracing::debug!( + type_name = std::any::type_name::<T>(), + address = %format_args!("{:x}", self.pointer), + "release jni OwnedPointer" + ); + assert!(self.pointer != 0); unsafe { drop(Box::from_raw(self.pointer as *mut T)) } } } @@ -389,7 +404,7 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_iteratorClose<'a>( _env: EnvParam<'a>, pointer: OwnedPointer<JavaBindingIterator<'a>>, ) { - pointer.drop() + pointer.release() } #[no_mangle] @@ -419,7 +434,7 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkClose( _env: EnvParam<'_>, chunk: OwnedPointer<StreamChunk>, ) { - chunk.drop() + chunk.release() } #[no_mangle] @@ -1052,6 +1067,14 @@ extern "system" fn Java_com_risingwave_java_binding_Binding_sendCdcSourceErrorTo }) } +#[no_mangle] +extern "system" fn Java_com_risingwave_java_binding_Binding_cdcSourceSenderClose( + _env: EnvParam<'_>, + channel: OwnedPointer<JniSenderType<GetEventStreamResponse>>, +) { + channel.release(); +} + pub enum JniSinkWriterStreamRequest { PbRequest(SinkWriterStreamRequest), Chunk { diff --git a/src/jni_core/src/macros.rs b/src/jni_core/src/macros.rs index ecab1426ee395..08286d5985259 100644 --- a/src/jni_core/src/macros.rs +++ b/src/jni_core/src/macros.rs @@ -502,6 +502,8 @@ macro_rules! for_all_plain_native_methods { public static native boolean sendCdcSourceErrorToChannel(long channelPtr, String errorMsg); + public static native void cdcSourceSenderClose(long channelPtr); + public static native com.risingwave.java.binding.JniSinkWriterStreamRequest recvSinkWriterRequestFromChannel(long channelPtr); @@ -930,6 +932,7 @@ mod tests { iteratorGetArrayValue (JILjava/lang/Class;)Ljava/lang/Object;, sendCdcSourceMsgToChannel (J[B)Z, sendCdcSourceErrorToChannel (JLjava/lang/String;)Z, + cdcSourceSenderClose (J)V, recvSinkWriterRequestFromChannel (J)Lcom/risingwave/java/binding/JniSinkWriterStreamRequest;, sendSinkWriterResponseToChannel (J[B)Z, sendSinkWriterErrorToChannel (JLjava/lang/String;)Z,