Add set_deserializer method to Python KafkaSourceBuilder
fqaiser94 committed Jan 13, 2025
1 parent f6a077a commit 1bd40dc
Showing 2 changed files with 69 additions and 2 deletions.
52 changes: 51 additions & 1 deletion flink-python/pyflink/datastream/connectors/kafka.py
@@ -44,7 +44,9 @@
'KafkaOffsetResetStrategy',
'KafkaRecordSerializationSchema',
'KafkaRecordSerializationSchemaBuilder',
-    'KafkaTopicSelector'
+    'KafkaTopicSelector',
+    'KafkaRecordDeserializationSchema',
+    'SimpleStringValueKafkaRecordDeserializationSchema'
]


@@ -353,6 +355,38 @@ def ignore_failures_after_transaction_timeout(self) -> 'FlinkKafkaProducer':

# ---- KafkaSource ----

class KafkaRecordDeserializationSchema:
"""
    Base class for KafkaRecordDeserializationSchema. The Kafka record deserialization schema
    describes how to turn the byte messages delivered by Apache Kafka into data types
    (Java/Scala objects) that are processed by Flink.
In addition, the KafkaRecordDeserializationSchema describes the produced type which lets
Flink create internal serializers and structures to handle the type.
"""
def __init__(self, j_kafka_record_deserialization_schema=None):
self.j_kafka_record_deserialization_schema = j_kafka_record_deserialization_schema


class SimpleStringValueKafkaRecordDeserializationSchema(KafkaRecordDeserializationSchema):
"""
    Very simple deserialization schema for string values. By default, the deserializer uses
    'UTF-8' for byte-to-string conversion.
"""

def __init__(self, charset: str = 'UTF-8'):
gate_way = get_gateway()
j_char_set = gate_way.jvm.java.nio.charset.Charset.forName(charset)
j_simple_string_serialization_schema = gate_way.jvm \
.org.apache.flink.api.common.serialization.SimpleStringSchema(j_char_set)
j_kafka_record_deserialization_schema = gate_way.jvm \
.org.apache.flink.connector.kafka.source.reader.deserializer \
.KafkaRecordDeserializationSchema.valueOnly(j_simple_string_serialization_schema)
KafkaRecordDeserializationSchema.__init__(
self, j_kafka_record_deserialization_schema=j_kafka_record_deserialization_schema)


# ---- KafkaSource ----

class KafkaSource(Source):
"""
@@ -611,6 +645,22 @@ def set_value_only_deserializer(self, deserialization_schema: DeserializationSchema) -> 'KafkaSourceBuilder':
self._j_builder.setValueOnlyDeserializer(deserialization_schema._j_deserialization_schema)
return self

def set_deserializer(
self,
kafka_record_deserialization_schema: KafkaRecordDeserializationSchema
) -> 'KafkaSourceBuilder':
"""
Sets the :class:`~pyflink.datastream.connectors.kafka.KafkaRecordDeserializationSchema`
        for deserializing Kafka ConsumerRecords.

        :param kafka_record_deserialization_schema: the :class:`KafkaRecordDeserializationSchema`
            to use for deserialization.
        :return: this KafkaSourceBuilder.
"""
self._j_builder.setDeserializer(
kafka_record_deserialization_schema.j_kafka_record_deserialization_schema)
return self

def set_client_id_prefix(self, prefix: str) -> 'KafkaSourceBuilder':
"""
Sets the client id prefix of this KafkaSource.
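With set_deserializer in place, a record-level deserializer can be handed to the source builder directly. Below is a minimal usage sketch, not part of the commit: the broker address 'localhost:9092', the topic 'input-topic', and the job wiring are illustrative placeholders, while the builder calls are the same ones exercised by the test that follows.

from pyflink.common import WatermarkStrategy
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import (
    KafkaSource, SimpleStringValueKafkaRecordDeserializationSchema)

env = StreamExecutionEnvironment.get_execution_environment()

# Deserialize only the record value, as a UTF-8 string (the class default).
source = KafkaSource.builder() \
    .set_bootstrap_servers('localhost:9092') \
    .set_topics('input-topic') \
    .set_deserializer(SimpleStringValueKafkaRecordDeserializationSchema()) \
    .build()

ds = env.from_source(source, WatermarkStrategy.no_watermarks(), 'kafka_source')
ds.print()
env.execute('kafka-string-source-example')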
19 changes: 18 additions & 1 deletion flink-python/pyflink/datastream/connectors/tests/test_kafka.py
@@ -29,7 +29,8 @@
from pyflink.datastream.connectors.base import DeliveryGuarantee
from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \
KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \
-    FlinkKafkaProducer, FlinkKafkaConsumer
+    FlinkKafkaProducer, FlinkKafkaConsumer, KafkaRecordDeserializationSchema, \
+    SimpleStringValueKafkaRecordDeserializationSchema
from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema
from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema
from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema
@@ -332,6 +333,22 @@ def _check(schema: DeserializationSchema, class_name: str):
'org.apache.flink.formats.avro.AvroRowDeserializationSchema'
)

def test_set_kafka_record_deserialization_schema(self):
def _check(schema: KafkaRecordDeserializationSchema, java_class_name: str):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_deserializer(schema) \
.build()
kafka_record_deserialization_schema = get_field_value(source.get_java_function(),
'deserializationSchema')
self.assertEqual(kafka_record_deserialization_schema.getClass().getCanonicalName(),
java_class_name)

_check(SimpleStringValueKafkaRecordDeserializationSchema(),
'org.apache.flink.connector.kafka.source.reader.deserializer.'
'KafkaValueOnlyDeserializationSchemaWrapper')

def _check_reader_handled_offsets_initializer(self,
source: KafkaSource,
offset: int,
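The base class is a thin handle around a Java-side object, so schemas beyond plain strings can be wrapped through the Py4J gateway. A sketch under that assumption, reusing the same valueOnly factory the new class calls internally; the ISO-8859-1 charset is only an illustrative stand-in for a custom value format.

from pyflink.java_gateway import get_gateway
from pyflink.datastream.connectors.kafka import KafkaRecordDeserializationSchema

gateway = get_gateway()

# Build a Java SimpleStringSchema with an explicit charset ...
j_charset = gateway.jvm.java.nio.charset.Charset.forName('ISO-8859-1')
j_value_schema = gateway.jvm \
    .org.apache.flink.api.common.serialization.SimpleStringSchema(j_charset)

# ... and lift it into a value-only Kafka record deserializer, exactly as
# SimpleStringValueKafkaRecordDeserializationSchema.__init__ does above.
j_record_schema = gateway.jvm \
    .org.apache.flink.connector.kafka.source.reader.deserializer \
    .KafkaRecordDeserializationSchema.valueOnly(j_value_schema)

custom_schema = KafkaRecordDeserializationSchema(
    j_kafka_record_deserialization_schema=j_record_schema)
# custom_schema can now be passed to KafkaSourceBuilder.set_deserializer(...).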
