Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: message_as_jsonb to handle circle dep in protobuf #19935

Merged
merged 22 commits into from
Dec 31, 2024
54 changes: 54 additions & 0 deletions e2e_test/source_inline/kafka/protobuf/opentelemetry.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
control substitution on

system ok
rpk registry schema create "opentelemetry_common.proto" --schema "/risingwave/src/connector/codec/tests/test_data/opentelemetry_common.proto"

system ok
rpk registry schema create "opentelemetry_test-value" --schema "/dev/stdin" --references opentelemetry_common.proto:opentelemetry_common.proto:1 --type protobuf << EOF
syntax = "proto3";

package opentelemetry_test;

import "opentelemetry_common.proto";

message OTLPTestMessage {
opentelemetry.proto.common.v1.AnyValue any_value = 1;
opentelemetry.proto.common.v1.KeyValueList key_value_list = 2;
opentelemetry.proto.common.v1.InstrumentationScope instrumentation_scope = 3;
}
EOF

system ok
echo '{"any_value":{"string_value":"example"},"key_value_list":{"values":[{"key":"key1","value":{"string_value":"value1"}},{"key":"key2","value":{"int_value":42}}]},"instrumentation_scope":{"name":"test-scope","version":"1.0"}}' | rpk topic produce "opentelemetry_test" --schema-id=topic --schema-type="opentelemetry_test.OTLPTestMessage" --allow-auto-topic-creation

statement ok
create table opentelemetry_test with ( ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, topic = 'opentelemetry_test' ) format plain encode protobuf ( schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}', message = 'opentelemetry_test.OTLPTestMessage', messages_as_jsonb = 'opentelemetry.proto.common.v1.ArrayValue,opentelemetry.proto.common.v1.KeyValueList,opentelemetry.proto.common.v1.AnyValue');

statement ok
flush;

sleep 1s

query T
select count(*) from opentelemetry_test;
----
1

query TTT
select any_value, key_value_list, instrumentation_scope from opentelemetry_test;
----
{"stringValue": "example"} {"values": [{"key": "key1", "value": {"stringValue": "value1"}}, {"key": "key2", "value": {"intValue": "42"}}]} (test-scope,1.0,{},0)

# ==== clean up ====

statement ok
drop table opentelemetry_test;

system ok
rpk topic delete opentelemetry_test;

system ok
rpk registry subject delete "opentelemetry_test-value"

system ok
rpk registry subject delete "opentelemetry_common.proto"
45 changes: 45 additions & 0 deletions e2e_test/source_inline/kafka/protobuf/recursive_overflow.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
control substitution on

system ok
rpk registry schema create "recursive_complex-value" --schema "/dev/stdin" --type protobuf << EOF
syntax = "proto3";

// a recursive complex type can cause stack overflow in the frontend when inferring the schema

package recursive_complex;

message AnyValue {
oneof value {
string string_value = 1;
int32 int_value = 2;
double double_value = 3;
bool bool_value = 4;
ArrayValue array_value = 5;
}
}

message ArrayValue {
AnyValue value1 = 1;
AnyValue value2 = 2;
ArrayValue array_value = 3;
}

EOF

system ok
echo '{"array_value":{"value1":{"string_value":"This is a string value"},"value2":{"int_value":42},"array_value":{"value1":{"double_value":3.14159},"value2":{"bool_value":true},"array_value":{"value1":{"string_value":"Deeply nested string"},"value2":{"int_value":100}}}}}' | rpk topic produce "recursive_complex" --schema-id=topic --schema-type="recursive_complex.AnyValue" --allow-auto-topic-creation

# the test just make sure the table can finish create process
statement ok
create table recursive_complex with ( ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, topic = 'recursive_complex' ) format plain encode protobuf ( schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}', message = 'recursive_complex.AnyValue', messages_as_jsonb = 'recursive_complex.AnyValue,recursive_complex.ArrayValue');

# ==== clean up ====

statement ok
drop table recursive_complex;

system ok
rpk topic delete recursive_complex;

system ok
rpk registry subject delete "recursive_complex-value"
26 changes: 17 additions & 9 deletions src/connector/codec/src/decoder/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

pub mod parser;
use std::borrow::Cow;
use std::collections::HashSet;
use std::sync::LazyLock;

use parser::from_protobuf_value;
Expand All @@ -24,13 +25,17 @@ use thiserror_ext::AsReport;

use super::{uncategorized, Access, AccessResult};

pub struct ProtobufAccess {
pub struct ProtobufAccess<'a> {
message: DynamicMessage,
messages_as_jsonb: &'a HashSet<String>,
}

impl ProtobufAccess {
pub fn new(message: DynamicMessage) -> Self {
Self { message }
impl<'a> ProtobufAccess<'a> {
pub fn new(message: DynamicMessage, messages_as_jsonb: &'a HashSet<String>) -> Self {
Self {
message,
messages_as_jsonb,
}
}

#[cfg(test)]
Expand All @@ -39,7 +44,7 @@ impl ProtobufAccess {
}
}

impl Access for ProtobufAccess {
impl Access for ProtobufAccess<'_> {
fn access<'a>(&'a self, path: &[&str], type_expected: &DataType) -> AccessResult<DatumCow<'a>> {
debug_assert_eq!(1, path.len());
let field_desc = self
Expand All @@ -56,12 +61,15 @@ impl Access for ProtobufAccess {
})?;

match self.message.get_field(&field_desc) {
Cow::Borrowed(value) => from_protobuf_value(&field_desc, value, type_expected),
Cow::Borrowed(value) => {
from_protobuf_value(&field_desc, value, type_expected, self.messages_as_jsonb)
}

// `Owned` variant occurs only if there's no such field and the default value is returned.
Cow::Owned(value) => from_protobuf_value(&field_desc, &value, type_expected)
// enforce `Owned` variant to avoid returning a reference to a temporary value
.map(|d| d.to_owned_datum().into()),
Cow::Owned(value) => {
from_protobuf_value(&field_desc, &value, type_expected, self.messages_as_jsonb)
.map(|d| d.to_owned_datum().into())
}
}
}
}
68 changes: 55 additions & 13 deletions src/connector/codec/src/decoder/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use anyhow::Context;
use itertools::Itertools;
use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
Expand All @@ -26,14 +28,22 @@ use thiserror_ext::Macro;

use crate::decoder::{uncategorized, AccessError, AccessResult};

pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";

pub fn pb_schema_to_column_descs(
message_descriptor: &MessageDescriptor,
messages_as_jsonb: &HashSet<String>,
) -> anyhow::Result<Vec<ColumnDesc>> {
let mut columns = Vec::with_capacity(message_descriptor.fields().len());
let mut index = 0;
let mut parse_trace: Vec<String> = vec![];
for field in message_descriptor.fields() {
columns.push(pb_field_to_col_desc(&field, &mut index, &mut parse_trace)?);
columns.push(pb_field_to_col_desc(
&field,
&mut index,
&mut parse_trace,
messages_as_jsonb,
)?);
}

Ok(columns)
Expand All @@ -44,15 +54,18 @@ fn pb_field_to_col_desc(
field_descriptor: &FieldDescriptor,
index: &mut i32,
parse_trace: &mut Vec<String>,
messages_as_jsonb: &HashSet<String>,
) -> anyhow::Result<ColumnDesc> {
let field_type = protobuf_type_mapping(field_descriptor, parse_trace)
let field_type = protobuf_type_mapping(field_descriptor, parse_trace, messages_as_jsonb)
.context("failed to map protobuf type")?;
if let Kind::Message(m) = field_descriptor.kind() {
if let Kind::Message(m) = field_descriptor.kind()
&& !messages_as_jsonb.contains(m.full_name())
{
let field_descs = if let DataType::List { .. } = field_type {
vec![]
} else {
m.fields()
.map(|f| pb_field_to_col_desc(&f, index, parse_trace))
.map(|f| pb_field_to_col_desc(&f, index, parse_trace, messages_as_jsonb))
.try_collect()?
};
*index += 1;
Expand Down Expand Up @@ -92,10 +105,12 @@ fn detect_loop_and_push(
let identifier = format!("{}({})", fd.name(), fd.full_name());
if trace.iter().any(|s| s == identifier.as_str()) {
bail_protobuf_type_error!(
"circular reference detected: {}, conflict with {}, kind {:?}",
"circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
trace.iter().format("->"),
identifier,
fd.kind(),
fd.kind(),
PROTOBUF_MESSAGES_AS_JSONB,
);
}
trace.push(identifier);
Expand All @@ -106,6 +121,7 @@ pub fn from_protobuf_value<'a>(
field_desc: &FieldDescriptor,
value: &'a Value,
type_expected: &DataType,
messages_as_jsonb: &'a HashSet<String>,
) -> AccessResult<DatumCow<'a>> {
let kind = field_desc.kind();

Expand Down Expand Up @@ -136,7 +152,7 @@ pub fn from_protobuf_value<'a>(
ScalarImpl::Utf8(enum_symbol.name().into())
}
Value::Message(dyn_msg) => {
if dyn_msg.descriptor().full_name() == "google.protobuf.Any" {
if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
ScalarImpl::Jsonb(JsonbVal::from(
serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
))
Expand All @@ -159,8 +175,13 @@ pub fn from_protobuf_value<'a>(
};
let value = dyn_msg.get_field(&field_desc);
rw_values.push(
from_protobuf_value(&field_desc, &value, expected_field_type)?
.to_owned_datum(),
from_protobuf_value(
&field_desc,
&value,
expected_field_type,
messages_as_jsonb,
)?
.to_owned_datum(),
);
}
ScalarImpl::Struct(StructValue::new(rw_values))
Expand All @@ -176,7 +197,12 @@ pub fn from_protobuf_value<'a>(
};
let mut builder = element_type.create_array_builder(values.len());
for value in values {
builder.append(from_protobuf_value(field_desc, value, element_type)?);
builder.append(from_protobuf_value(
field_desc,
value,
element_type,
messages_as_jsonb,
)?);
}
ScalarImpl::List(ListValue::new(builder.finish()))
}
Expand Down Expand Up @@ -209,11 +235,13 @@ pub fn from_protobuf_value<'a>(
&map_desc.map_entry_key_field(),
&key.clone().into(),
map_type.key(),
messages_as_jsonb,
)?);
value_builder.append(from_protobuf_value(
&map_desc.map_entry_value_field(),
value,
map_type.value(),
messages_as_jsonb,
)?);
}
let keys = key_builder.finish();
Expand All @@ -231,6 +259,7 @@ pub fn from_protobuf_value<'a>(
fn protobuf_type_mapping(
field_descriptor: &FieldDescriptor,
parse_trace: &mut Vec<String>,
messages_as_jsonb: &HashSet<String>,
) -> std::result::Result<DataType, ProtobufTypeError> {
detect_loop_and_push(parse_trace, field_descriptor)?;
let mut t = match field_descriptor.kind() {
Expand All @@ -245,20 +274,33 @@ fn protobuf_type_mapping(
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => {
if m.full_name() == "google.protobuf.Any" {
if messages_as_jsonb.contains(m.full_name()) {
// Well-Known Types are identified by their full name
DataType::Jsonb
} else if m.is_map_entry() {
// Map is equivalent to `repeated MapFieldEntry map_field = N;`
debug_assert!(field_descriptor.is_map());
let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?;
let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?;
let key = protobuf_type_mapping(
&m.map_entry_key_field(),
parse_trace,
messages_as_jsonb,
)?;
let value = protobuf_type_mapping(
&m.map_entry_value_field(),
parse_trace,
messages_as_jsonb,
)?;
_ = parse_trace.pop();
return Ok(DataType::Map(MapType::from_kv(key, value)));
} else {
let fields = m
.fields()
.map(|f| Ok((f.name().to_owned(), protobuf_type_mapping(&f, parse_trace)?)))
.map(|f| {
Ok((
f.name().to_owned(),
protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
))
})
.try_collect::<_, Vec<_>, _>()?;
StructType::new(fields).into()
}
Expand Down
Loading
Loading