Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: message_as_jsonb to handle circle dep in protobuf #19935

Merged
merged 22 commits into from
Dec 31, 2024
50 changes: 50 additions & 0 deletions e2e_test/source_inline/kafka/protobuf/opentelemetry.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
system ok
rpk registry schema create "opentelemetry_common.proto" --schema "/risingwave/src/connector/codec/tests/test_data/opentelemetry_common.proto"

system ok
rpk registry schema create "opentelemetry_test-value" --schema "/risingwave/src/connector/codec/tests/test_data/opentelemetry_test.proto" --references opentelemetry_common.proto:opentelemetry_common.proto:1

system ok
echo '{"any_value":{"string_value":"example"},"key_value_list":{"values":[{"key":"key1","value":{"string_value":"value1"}},{"key":"key2","value":{"int_value":42}}]},"instrumentation_scope":{"name":"test-scope","version":"1.0"}}' | rpk topic produce "opentelemetry_test" --schema-id=topic --schema-type="opentelemetry_test.OTLPTestMessage" --allow-auto-topic-creation

statement ok
create table opentelemetry_test with ( ${RISEDEV_KAFKA_WITH_OPTIONS_COMMON}, topic = 'opentelemetry_test' ) format plain encode protobuf ( schema.registry = '${RISEDEV_SCHEMA_REGISTRY_URL}, message = 'opentelemetry_test.OTLPTestMessage', messages_as_jsonb = 'opentelemetry.proto.common.v1.ArrayValue,opentelemetry.proto.common.v1.KeyValueList,opentelemetry.proto.common.v1.AnyValue');
tabVersion marked this conversation as resolved.
Show resolved Hide resolved

statement ok
flush;

sleep 1s

query T
select count(*) from opentelemetry_test;
----
1

query TTIITTT
select (any_value).string_value, (any_value).bool_value, (any_value).int_value, (any_value).double_value, (any_value).array_value, (any_value).kvlist_value, (any_value).bytes_value from opentelemetry_test ;
----
example f 0 0 {} {} \x

query T
select key_value_list from opentelemetry_test;
----
{"values": [{"key": "key1", "value": {"stringValue": "value1"}}, {"key": "key2", "value": {"intValue": "42"}}]}

query TTTI
select (instrumentation_scope).name, (instrumentation_scope).version, (instrumentation_scope).attributes, (instrumentation_scope).dropped_attributes_count from opentelemetry_test;
----
test-scope 1.0 {} 0

# ==== clean up ====

statement ok
drop table opentelemetry_test;

system ok
rpk topic delete opentelemetry_test;

system ok
rpk registry subject delete "opentelemetry_test-value"

system ok
rpk registry subject delete "opentelemetry_common.proto"
26 changes: 17 additions & 9 deletions src/connector/codec/src/decoder/protobuf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

pub mod parser;
use std::borrow::Cow;
use std::collections::HashSet;
use std::sync::LazyLock;

use parser::from_protobuf_value;
Expand All @@ -24,13 +25,17 @@ use thiserror_ext::AsReport;

use super::{uncategorized, Access, AccessResult};

pub struct ProtobufAccess {
pub struct ProtobufAccess<'a> {
message: DynamicMessage,
messages_as_jsonb: &'a HashSet<String>,
}

impl ProtobufAccess {
pub fn new(message: DynamicMessage) -> Self {
Self { message }
impl<'a> ProtobufAccess<'a> {
pub fn new(message: DynamicMessage, messages_as_jsonb: &'a HashSet<String>) -> Self {
Self {
message,
messages_as_jsonb,
}
}

#[cfg(test)]
Expand All @@ -39,7 +44,7 @@ impl ProtobufAccess {
}
}

impl Access for ProtobufAccess {
impl Access for ProtobufAccess<'_> {
fn access<'a>(&'a self, path: &[&str], type_expected: &DataType) -> AccessResult<DatumCow<'a>> {
debug_assert_eq!(1, path.len());
let field_desc = self
Expand All @@ -56,12 +61,15 @@ impl Access for ProtobufAccess {
})?;

match self.message.get_field(&field_desc) {
Cow::Borrowed(value) => from_protobuf_value(&field_desc, value, type_expected),
Cow::Borrowed(value) => {
from_protobuf_value(&field_desc, value, type_expected, self.messages_as_jsonb)
}

// `Owned` variant occurs only if there's no such field and the default value is returned.
Cow::Owned(value) => from_protobuf_value(&field_desc, &value, type_expected)
// enforce `Owned` variant to avoid returning a reference to a temporary value
.map(|d| d.to_owned_datum().into()),
Cow::Owned(value) => {
from_protobuf_value(&field_desc, &value, type_expected, self.messages_as_jsonb)
.map(|d| d.to_owned_datum().into())
}
}
}
}
64 changes: 52 additions & 12 deletions src/connector/codec/src/decoder/protobuf/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use anyhow::Context;
use itertools::Itertools;
use prost_reflect::{Cardinality, FieldDescriptor, Kind, MessageDescriptor, ReflectMessage, Value};
Expand All @@ -26,14 +28,22 @@ use thiserror_ext::Macro;

use crate::decoder::{uncategorized, AccessError, AccessResult};

pub const PROTOBUF_MESSAGES_AS_JSONB: &str = "messages_as_jsonb";

pub fn pb_schema_to_column_descs(
message_descriptor: &MessageDescriptor,
messages_as_jsonb: &HashSet<String>,
) -> anyhow::Result<Vec<ColumnDesc>> {
let mut columns = Vec::with_capacity(message_descriptor.fields().len());
let mut index = 0;
let mut parse_trace: Vec<String> = vec![];
for field in message_descriptor.fields() {
columns.push(pb_field_to_col_desc(&field, &mut index, &mut parse_trace)?);
columns.push(pb_field_to_col_desc(
&field,
&mut index,
&mut parse_trace,
messages_as_jsonb,
)?);
}

Ok(columns)
Expand All @@ -44,15 +54,16 @@ fn pb_field_to_col_desc(
field_descriptor: &FieldDescriptor,
index: &mut i32,
parse_trace: &mut Vec<String>,
messages_as_jsonb: &HashSet<String>,
) -> anyhow::Result<ColumnDesc> {
let field_type = protobuf_type_mapping(field_descriptor, parse_trace)
let field_type = protobuf_type_mapping(field_descriptor, parse_trace, messages_as_jsonb)
.context("failed to map protobuf type")?;
if let Kind::Message(m) = field_descriptor.kind() {
let field_descs = if let DataType::List { .. } = field_type {
vec![]
} else {
m.fields()
.map(|f| pb_field_to_col_desc(&f, index, parse_trace))
.map(|f| pb_field_to_col_desc(&f, index, parse_trace, messages_as_jsonb))
.try_collect()?
};
*index += 1;
Expand Down Expand Up @@ -92,10 +103,12 @@ fn detect_loop_and_push(
let identifier = format!("{}({})", fd.name(), fd.full_name());
if trace.iter().any(|s| s == identifier.as_str()) {
bail_protobuf_type_error!(
"circular reference detected: {}, conflict with {}, kind {:?}",
"circular reference detected: {}, conflict with {}, kind {:?}. Adding {:?} to {:?} may help.",
trace.iter().format("->"),
identifier,
fd.kind(),
fd.kind(),
PROTOBUF_MESSAGES_AS_JSONB,
);
}
trace.push(identifier);
Expand All @@ -106,6 +119,7 @@ pub fn from_protobuf_value<'a>(
field_desc: &FieldDescriptor,
value: &'a Value,
type_expected: &DataType,
messages_as_jsonb: &'a HashSet<String>,
) -> AccessResult<DatumCow<'a>> {
let kind = field_desc.kind();

Expand Down Expand Up @@ -136,7 +150,7 @@ pub fn from_protobuf_value<'a>(
ScalarImpl::Utf8(enum_symbol.name().into())
}
Value::Message(dyn_msg) => {
if dyn_msg.descriptor().full_name() == "google.protobuf.Any" {
if messages_as_jsonb.contains(dyn_msg.descriptor().full_name()) {
ScalarImpl::Jsonb(JsonbVal::from(
serde_json::to_value(dyn_msg).map_err(AccessError::ProtobufAnyToJson)?,
))
Expand All @@ -159,8 +173,13 @@ pub fn from_protobuf_value<'a>(
};
let value = dyn_msg.get_field(&field_desc);
rw_values.push(
from_protobuf_value(&field_desc, &value, expected_field_type)?
.to_owned_datum(),
from_protobuf_value(
&field_desc,
&value,
expected_field_type,
messages_as_jsonb,
)?
.to_owned_datum(),
);
}
ScalarImpl::Struct(StructValue::new(rw_values))
Expand All @@ -176,7 +195,12 @@ pub fn from_protobuf_value<'a>(
};
let mut builder = element_type.create_array_builder(values.len());
for value in values {
builder.append(from_protobuf_value(field_desc, value, element_type)?);
builder.append(from_protobuf_value(
field_desc,
value,
element_type,
messages_as_jsonb,
)?);
}
ScalarImpl::List(ListValue::new(builder.finish()))
}
Expand Down Expand Up @@ -209,11 +233,13 @@ pub fn from_protobuf_value<'a>(
&map_desc.map_entry_key_field(),
&key.clone().into(),
map_type.key(),
messages_as_jsonb,
)?);
value_builder.append(from_protobuf_value(
&map_desc.map_entry_value_field(),
value,
map_type.value(),
messages_as_jsonb,
)?);
}
let keys = key_builder.finish();
Expand All @@ -231,6 +257,7 @@ pub fn from_protobuf_value<'a>(
fn protobuf_type_mapping(
field_descriptor: &FieldDescriptor,
parse_trace: &mut Vec<String>,
messages_as_jsonb: &HashSet<String>,
) -> std::result::Result<DataType, ProtobufTypeError> {
detect_loop_and_push(parse_trace, field_descriptor)?;
let mut t = match field_descriptor.kind() {
Expand All @@ -245,20 +272,33 @@ fn protobuf_type_mapping(
Kind::Uint64 | Kind::Fixed64 => DataType::Decimal,
Kind::String => DataType::Varchar,
Kind::Message(m) => {
if m.full_name() == "google.protobuf.Any" {
if messages_as_jsonb.contains(m.full_name()) {
// Well-Known Types are identified by their full name
DataType::Jsonb
} else if m.is_map_entry() {
// Map is equivalent to `repeated MapFieldEntry map_field = N;`
debug_assert!(field_descriptor.is_map());
let key = protobuf_type_mapping(&m.map_entry_key_field(), parse_trace)?;
let value = protobuf_type_mapping(&m.map_entry_value_field(), parse_trace)?;
let key = protobuf_type_mapping(
&m.map_entry_key_field(),
parse_trace,
messages_as_jsonb,
)?;
let value = protobuf_type_mapping(
&m.map_entry_value_field(),
parse_trace,
messages_as_jsonb,
)?;
_ = parse_trace.pop();
return Ok(DataType::Map(MapType::from_kv(key, value)));
} else {
let fields = m
.fields()
.map(|f| Ok((f.name().to_owned(), protobuf_type_mapping(&f, parse_trace)?)))
.map(|f| {
Ok((
f.name().to_owned(),
protobuf_type_mapping(&f, parse_trace, messages_as_jsonb)?,
))
})
.try_collect::<_, Vec<_>, _>()?;
StructType::new(fields).into()
}
Expand Down
25 changes: 16 additions & 9 deletions src/connector/codec/tests/integration_tests/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ mod recursive;
#[rustfmt::skip]
#[allow(clippy::all)]
mod all_types;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use anyhow::Context;
use prost::Message;
Expand All @@ -39,7 +39,10 @@ fn check(
expected_risingwave_schema: expect_test::Expect,
expected_risingwave_data: expect_test::Expect,
) {
let rw_schema = pb_schema_to_column_descs(&pb_schema);
let rw_schema = pb_schema_to_column_descs(
&pb_schema,
&HashSet::from(["google.protobuf.Any".to_owned()]),
);

if let Err(e) = rw_schema {
expected_risingwave_schema.assert_eq(&e.to_report_string_pretty());
Expand All @@ -58,8 +61,12 @@ fn check(
));

let mut data_str = vec![];
let messages_as_jsonb = HashSet::from(["google.protobuf.Any".to_owned()]);
for data in pb_data {
let access = ProtobufAccess::new(DynamicMessage::decode(pb_schema.clone(), *data).unwrap());
let access = ProtobufAccess::new(
DynamicMessage::decode(pb_schema.clone(), *data).unwrap(),
&messages_as_jsonb,
);
let mut row = vec![];
for col in &rw_schema {
let rw_data = access.access(&[&col.name], &col.data_type);
Expand Down Expand Up @@ -235,11 +242,11 @@ fn test_any_schema() -> anyhow::Result<()> {
// }
static ANY_DATA_3: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x56\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65";

// // id: 12345
// // any_value: {
// // type_url: "type.googleapis.com/test.StringXalue"
// // value: "\n\010John Doe"
// // }
// id: 12345
// any_value: {
// type_url: "type.googleapis.com/test.StringXalue"
// value: "\n\010John Doe"
// }
static ANY_DATA_INVALID: &[u8] = b"\x08\xb9\x60\x12\x32\x0a\x24\x74\x79\x70\x65\x2e\x67\x6f\x6f\x67\x6c\x65\x61\x70\x69\x73\x2e\x63\x6f\x6d\x2f\x74\x65\x73\x74\x2e\x53\x74\x72\x69\x6e\x67\x58\x61\x6c\x75\x65\x12\x0a\x0a\x08\x4a\x6f\x68\x6e\x20\x44\x6f\x65";

// validate the binary data is correct
Expand Down Expand Up @@ -710,7 +717,7 @@ fn test_recursive() -> anyhow::Result<()> {
failed to map protobuf type

Caused by:
circular reference detected: parent(recursive.ComplexRecursiveMessage.parent)->siblings(recursive.ComplexRecursiveMessage.Parent.siblings), conflict with parent(recursive.ComplexRecursiveMessage.parent), kind recursive.ComplexRecursiveMessage.Parent
circular reference detected: parent(recursive.ComplexRecursiveMessage.parent)->siblings(recursive.ComplexRecursiveMessage.Parent.siblings), conflict with parent(recursive.ComplexRecursiveMessage.parent), kind recursive.ComplexRecursiveMessage.Parent. Adding recursive.ComplexRecursiveMessage.Parent to "messages_as_jsonb" may help.
"#]],
expect![""],
);
Expand Down
Loading
Loading