From 8260ba1559fe49b4e19e78bb8a9385aee2b94a8c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 21 Dec 2023 16:47:27 +0200 Subject: [PATCH] Add support for Matrix -> Signal formatting --- ROADMAP.md | 2 +- main.go | 17 +- msgconv/matrixfmt/convert.go | 43 +++ msgconv/matrixfmt/convert_test.go | 159 ++++++++++ msgconv/matrixfmt/html.go | 483 ++++++++++++++++++++++++++++++ msgconv/signalfmt/convert.go | 2 +- msgconv/signalfmt/tags.go | 18 +- msgconv/signalfmt/tree.go | 43 ++- pkg/signalmeow/protobuf/extra.go | 3 + pkg/signalmeow/sending.go | 9 +- portal.go | 125 ++------ 11 files changed, 792 insertions(+), 112 deletions(-) create mode 100644 msgconv/matrixfmt/convert.go create mode 100644 msgconv/matrixfmt/convert_test.go create mode 100644 msgconv/matrixfmt/html.go create mode 100644 pkg/signalmeow/protobuf/extra.go diff --git a/ROADMAP.md b/ROADMAP.md index 6dd0d909..432caa71 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -3,7 +3,7 @@ * Matrix → Signal * [ ] Message content * [x] Text - * [ ] Formatting + * [x] Formatting * [x] Mentions * [ ] Media * [x] Images diff --git a/main.go b/main.go index c80ff2d9..ad38c185 100644 --- a/main.go +++ b/main.go @@ -35,6 +35,7 @@ import ( "go.mau.fi/mautrix-signal/config" "go.mau.fi/mautrix-signal/database" + "go.mau.fi/mautrix-signal/msgconv/matrixfmt" "go.mau.fi/mautrix-signal/msgconv/signalfmt" "go.mau.fi/mautrix-signal/pkg/signalmeow" ) @@ -116,7 +117,7 @@ func (br *SignalBridge) Init() { br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB) br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent - formatParams = &signalfmt.FormatParams{ + signalFormatParams = &signalfmt.FormatParams{ GetUserInfo: func(uuid string) signalfmt.UserInfo { puppet := br.GetPuppetBySignalID(uuid) if puppet == nil { @@ -135,6 +136,20 @@ func (br *SignalBridge) Init() { } }, } + matrixFormatParams = &matrixfmt.HTMLParser{ + GetUUIDFromMXID: func(userID id.UserID) string { + parsed, ok := br.ParsePuppetMXID(userID) + if ok { + return parsed + } + // TODO only get if exists + user := br.GetUserByMXID(userID) + if user != nil && user.SignalID != "" { + return user.SignalID + } + return "" + }, + } signalmeow.HackyCaptionToggle = br.Config.Bridge.CaptionInMessage } diff --git a/msgconv/matrixfmt/convert.go b/msgconv/matrixfmt/convert.go new file mode 100644 index 00000000..8e20e521 --- /dev/null +++ b/msgconv/matrixfmt/convert.go @@ -0,0 +1,43 @@ +// mautrix-signal - A Matrix-Signal puppeting bridge. +// Copyright (C) 2023 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package matrixfmt + +import ( + "maunium.net/go/mautrix/event" + + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" +) + +func Parse(parser *HTMLParser, content *event.MessageEventContent) (string, []*signalpb.BodyRange) { + if content.Format != event.FormatHTML { + return content.Body, nil + } + ctx := NewContext() + ctx.AllowedMentions = content.Mentions + parsed := parser.Parse(content.FormattedBody, ctx) + if parsed == nil { + return "", nil + } + var bodyRanges []*signalpb.BodyRange + if len(parsed.Entities) > 0 { + bodyRanges = make([]*signalpb.BodyRange, len(parsed.Entities)) + for i, ent := range parsed.Entities { + bodyRanges[i] = ent.Proto() + } + } + return parsed.String.String(), bodyRanges +} diff --git a/msgconv/matrixfmt/convert_test.go b/msgconv/matrixfmt/convert_test.go new file mode 100644 index 00000000..719e4be0 --- /dev/null +++ b/msgconv/matrixfmt/convert_test.go @@ -0,0 +1,159 @@ +package matrixfmt_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-signal/msgconv/matrixfmt" + "go.mau.fi/mautrix-signal/msgconv/signalfmt" +) + +var formatParams = &matrixfmt.HTMLParser{ + GetUUIDFromMXID: func(id id.UserID) string { + if id.Homeserver() == "signal" { + return id.Localpart() + } + return "" + }, +} + +func TestParse_Empty(t *testing.T) { + text, entities := matrixfmt.Parse(formatParams, &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "", + }) + assert.Equal(t, "", text) + assert.Empty(t, entities) +} + +func TestParse_EmptyHTML(t *testing.T) { + text, entities := matrixfmt.Parse(formatParams, &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "", + Format: event.FormatHTML, + FormattedBody: "", + }) + assert.Equal(t, "", text) + assert.Empty(t, entities) +} + +func TestParse_Plaintext(t *testing.T) { + text, entities := matrixfmt.Parse(formatParams, &event.MessageEventContent{ + MsgType: event.MsgText, + Body: "Hello world!", + }) + assert.Equal(t, "Hello world!", text) + assert.Empty(t, entities) +} + +func TestParse_HTML(t *testing.T) { + tests := []struct { + name string + in string + out string + ent signalfmt.BodyRangeList + }{ + {name: "Plain", in: "Hello, World!", out: "Hello, World!"}, + {name: "Basic", in: "Hello, World!", out: "Hello, World!", ent: signalfmt.BodyRangeList{{ + Start: 0, + Length: 5, + Value: signalfmt.StyleBold, + }}}, + { + name: "MultiBasic", + in: "Hello, World!", + out: "Hello, World!", + ent: signalfmt.BodyRangeList{{ + Start: 0, + Length: 5, + Value: signalfmt.StyleBold, + }, { + Start: 0, + Length: 4, + Value: signalfmt.StyleItalic, + }, { + Start: 7, + Length: 5, + Value: signalfmt.StyleStrikethrough, + }, { + Start: 9, + Length: 3, + Value: signalfmt.StyleSpoiler, + }, { + Start: 12, + Length: 1, + Value: signalfmt.StyleMonospace, + }}, + }, + { + name: "TrimSpace", + in: " Hello ", + out: "Hello", + ent: signalfmt.BodyRangeList{{ + Start: 0, + Length: 5, + Value: signalfmt.StyleBold, + }}, + }, + { + name: "List", + in: "", + out: "* woof\n* meow\n* hmm\n meow\n* > meow\n > \n > # meow", + ent: signalfmt.BodyRangeList{{ + Start: 9, + Length: 4, + Value: signalfmt.StyleBold, + }, { + Start: 16, + Length: 3, + Value: signalfmt.StyleMonospace, + }, { + // FIXME optimally this would be a single range with the previous one so the indent is also monospace + Start: 22, + Length: 4, + Value: signalfmt.StyleMonospace, + }, { + Start: 45, + Length: 6, + Value: signalfmt.StyleBold, + }}, + }, + { + name: "OrderedList", + in: "
  1. woof
  2. meow
  3. hmm\nmeow
  4. meow

    meow

", + out: "9. woof\n10. meow\n11. hmm\n meow\n12. > meow\n > \n > # meow", + ent: signalfmt.BodyRangeList{{ + Start: 13, + Length: 4, + Value: signalfmt.StyleBold, + }, { + Start: 22, + Length: 3, + Value: signalfmt.StyleMonospace, + }, { + Start: 30, + Length: 4, + Value: signalfmt.StyleMonospace, + }, { + Start: 59, + Length: 6, + Value: signalfmt.StyleBold, + }}, + }, + } + matrixfmt.DebugLog = func(format string, args ...any) { + fmt.Printf(format, args...) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + fmt.Println("--------------------------------------------------------------------------------") + parsed := formatParams.Parse(test.in, matrixfmt.NewContext()) + assert.Equal(t, test.out, parsed.String.String()) + assert.Equal(t, test.ent, parsed.Entities) + }) + } +} diff --git a/msgconv/matrixfmt/html.go b/msgconv/matrixfmt/html.go new file mode 100644 index 00000000..28daec6c --- /dev/null +++ b/msgconv/matrixfmt/html.go @@ -0,0 +1,483 @@ +package matrixfmt + +import ( + "fmt" + "math" + "strconv" + "strings" + + "golang.org/x/exp/slices" + "golang.org/x/net/html" + "maunium.net/go/mautrix/event" + "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-signal/msgconv/signalfmt" +) + +type EntityString struct { + String signalfmt.UTF16String + Entities signalfmt.BodyRangeList +} + +var DebugLog = func(format string, args ...any) {} + +func NewEntityString(val string) *EntityString { + DebugLog("NEW %q\n", val) + return &EntityString{ + String: signalfmt.NewUTF16String(val), + } +} + +func (es *EntityString) Split(at uint16) []*EntityString { + if at > 0x7F { + panic("cannot split at non-ASCII character") + } + if es == nil { + return []*EntityString{} + } + DebugLog("SPLIT %q %q %+v\n", es.String, rune(at), es.Entities) + var output []*EntityString + prevSplit := 0 + doSplit := func(i int) *EntityString { + newES := &EntityString{ + String: es.String[prevSplit:i], + } + for _, entity := range es.Entities { + if (entity.End() <= i || entity.End() > prevSplit) && (entity.Start >= prevSplit || entity.Start < i) { + entity = *entity.TruncateStart(prevSplit).TruncateEnd(i).Offset(-prevSplit) + if entity.Length > 0 { + newES.Entities = append(newES.Entities, entity) + } + } + } + return newES + } + for i, chr := range es.String { + if chr != at { + continue + } + newES := doSplit(i) + output = append(output, newES) + DebugLog(" -> %q %+v\n", newES.String, newES.Entities) + prevSplit = i + 1 + } + if prevSplit == 0 { + DebugLog(" -> NOOP\n") + return []*EntityString{es} + } + if prevSplit != len(es.String) { + newES := doSplit(len(es.String)) + output = append(output, newES) + DebugLog(" -> %q %+v\n", newES.String, newES.Entities) + } + DebugLog("SPLITEND\n") + return output +} + +func (es *EntityString) TrimSpace() *EntityString { + if es == nil { + return nil + } + DebugLog("TRIMSPACE %q %+v\n", es.String, es.Entities) + var cutEnd, cutStart int + for cutStart = 0; cutStart < len(es.String); cutStart++ { + switch es.String[cutStart] { + case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0: + continue + } + break + } + for cutEnd = len(es.String) - 1; cutEnd >= 0; cutEnd-- { + switch es.String[cutEnd] { + case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0: + continue + } + break + } + cutEnd++ + if cutStart == 0 && cutEnd == len(es.String) { + DebugLog(" -> NOOP\n") + return es + } + newEntities := es.Entities[:0] + for _, ent := range es.Entities { + ent = *ent.Offset(-cutStart).TruncateEnd(cutEnd) + if ent.Length > 0 { + newEntities = append(newEntities, ent) + } + } + es.String = es.String[cutStart:cutEnd] + es.Entities = newEntities + DebugLog(" -> %q %+v\n", es.String, es.Entities) + return es +} + +func JoinEntityString(with string, strings ...*EntityString) *EntityString { + withUTF16 := signalfmt.NewUTF16String(with) + totalLen := 0 + totalEntities := 0 + for _, s := range strings { + totalLen += len(s.String) + totalEntities += len(s.Entities) + } + str := make(signalfmt.UTF16String, 0, totalLen+len(strings)*len(withUTF16)) + entities := make(signalfmt.BodyRangeList, 0, totalEntities) + DebugLog("JOIN %q %d\n", with, len(strings)) + for _, s := range strings { + if s == nil || len(s.String) == 0 { + continue + } + DebugLog(" + %q %+v\n", s.String, s.Entities) + for _, entity := range s.Entities { + entity.Start += len(str) + entities = append(entities, entity) + } + str = append(str, s.String...) + str = append(str, withUTF16...) + } + DebugLog(" -> %q %+v\n", str, entities) + return &EntityString{ + String: str, + Entities: entities, + } +} + +func (es *EntityString) Format(value signalfmt.BodyRangeValue) *EntityString { + if es == nil { + return nil + } + newEntity := signalfmt.BodyRange{ + Start: 0, + Length: len(es.String), + Value: value, + } + es.Entities = append(signalfmt.BodyRangeList{newEntity}, es.Entities...) + DebugLog("FORMAT %v %q %+v\n", value, es.String, es.Entities) + return es +} + +func (es *EntityString) Append(other *EntityString) *EntityString { + if es == nil { + return other + } else if other == nil { + return es + } + DebugLog("APPEND %q %+v\n + %q %+v\n", es.String, es.Entities, other.String, other.Entities) + for _, entity := range other.Entities { + entity.Start += len(es.String) + es.Entities = append(es.Entities, entity) + } + es.String = append(es.String, other.String...) + DebugLog(" -> %q %+v\n", es.String, es.Entities) + return es +} + +func (es *EntityString) AppendString(other string) *EntityString { + if es == nil { + return NewEntityString(other) + } else if len(other) == 0 { + return es + } + DebugLog("APPENDSTRING %q %+v\n + %q\n", es.String, es.Entities, other) + es.String = append(es.String, signalfmt.NewUTF16String(other)...) + DebugLog(" -> %q %+v\n", es.String, es.Entities) + return es +} + +type TagStack []string + +func (ts TagStack) Index(tag string) int { + for i := len(ts) - 1; i >= 0; i-- { + if ts[i] == tag { + return i + } + } + return -1 +} + +func (ts TagStack) Has(tag string) bool { + return ts.Index(tag) >= 0 +} + +type Context struct { + AllowedMentions *event.Mentions + TagStack TagStack + PreserveWhitespace bool +} + +func NewContext() Context { + return Context{ + TagStack: make(TagStack, 0, 4), + } +} + +func (ctx Context) WithTag(tag string) Context { + ctx.TagStack = append(ctx.TagStack, tag) + return ctx +} + +func (ctx Context) WithWhitespace() Context { + ctx.PreserveWhitespace = true + return ctx +} + +// HTMLParser is a somewhat customizable Matrix HTML parser. +type HTMLParser struct { + GetUUIDFromMXID func(id.UserID) string +} + +// TaggedString is a string that also contains a HTML tag. +type TaggedString struct { + *EntityString + tag string +} + +func (parser *HTMLParser) maybeGetAttribute(node *html.Node, attribute string) (string, bool) { + for _, attr := range node.Attr { + if attr.Key == attribute { + return attr.Val, true + } + } + return "", false +} + +func (parser *HTMLParser) getAttribute(node *html.Node, attribute string) string { + val, _ := parser.maybeGetAttribute(node, attribute) + return val +} + +// Digits counts the number of digits (and the sign, if negative) in an integer. +func Digits(num int) int { + if num == 0 { + return 1 + } else if num < 0 { + return Digits(-num) + 1 + } + return int(math.Floor(math.Log10(float64(num))) + 1) +} + +func (parser *HTMLParser) listToString(node *html.Node, ctx Context) *EntityString { + ordered := node.Data == "ol" + taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, ctx) + counter := 1 + indentLength := 0 + if ordered { + start := parser.getAttribute(node, "start") + if len(start) > 0 { + counter, _ = strconv.Atoi(start) + } + + longestIndex := (counter - 1) + len(taggedChildren) + indentLength = Digits(longestIndex) + } + indent := strings.Repeat(" ", indentLength+2) + var children []*EntityString + for _, child := range taggedChildren { + if child.tag != "li" { + continue + } + var prefix string + // TODO make bullets and numbering configurable + if ordered { + indexPadding := indentLength - Digits(counter) + if indexPadding < 0 { + // This will happen on negative start indexes where longestIndex is usually wrong, otherwise shouldn't happen + indexPadding = 0 + } + prefix = fmt.Sprintf("%d. %s", counter, strings.Repeat(" ", indexPadding)) + } else { + prefix = "* " + } + es := NewEntityString(prefix).Append(child.EntityString) + counter++ + parts := es.Split('\n') + for i, part := range parts[1:] { + parts[i+1] = NewEntityString(indent).Append(part) + } + children = append(children, parts...) + } + return JoinEntityString("\n", children...) +} + +func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) *EntityString { + str := parser.nodeToTagAwareString(node.FirstChild, ctx) + switch node.Data { + case "b", "strong": + return str.Format(signalfmt.StyleBold) + case "i", "em": + return str.Format(signalfmt.StyleItalic) + case "s", "del", "strike": + return str.Format(signalfmt.StyleStrikethrough) + case "u", "ins": + return str + case "tt", "code": + return str.Format(signalfmt.StyleMonospace) + } + return str +} + +func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) *EntityString { + str := parser.nodeToTagAwareString(node.FirstChild, ctx) + if node.Data == "span" { + _, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler") + if isSpoiler { + str = str.Format(signalfmt.StyleSpoiler) + } + } + return str +} + +func (parser *HTMLParser) headerToString(node *html.Node, ctx Context) *EntityString { + length := int(node.Data[1] - '0') + prefix := strings.Repeat("#", length) + " " + return NewEntityString(prefix).Append(parser.nodeToString(node.FirstChild, ctx)).Format(signalfmt.StyleBold) +} + +func (parser *HTMLParser) blockquoteToString(node *html.Node, ctx Context) *EntityString { + str := parser.nodeToTagAwareString(node.FirstChild, ctx) + childrenArr := str.TrimSpace().Split('\n') + for index, child := range childrenArr { + childrenArr[index] = NewEntityString("> ").Append(child) + } + return JoinEntityString("\n", childrenArr...) +} + +func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityString { + str := parser.nodeToTagAwareString(node.FirstChild, ctx) + href := parser.getAttribute(node, "href") + if len(href) == 0 { + return str + } + parsedMatrix, err := id.ParseMatrixURIOrMatrixToURL(href) + if err == nil && parsedMatrix != nil && parsedMatrix.Sigil1 == '@' { + mxid := parsedMatrix.UserID() + if ctx.AllowedMentions != nil && !slices.Contains(ctx.AllowedMentions.UserIDs, mxid) { + // Mention not allowed, use name as-is + return str + } + uuid := parser.GetUUIDFromMXID(mxid) + if uuid == "" { + // Don't include the link for mentions of non-Signal users, the name is enough + return str + } + return NewEntityString("\uFFFC").Format(signalfmt.Mention{ + UserInfo: signalfmt.UserInfo{ + MXID: mxid, + Name: str.String.String(), + }, + UUID: uuid, + }) + } + if str.String.String() == href { + return str + } + return str.AppendString(fmt.Sprintf(" (%s)", href)) +} + +func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) *EntityString { + ctx = ctx.WithTag(node.Data) + switch node.Data { + case "blockquote": + return parser.blockquoteToString(node, ctx) + case "ol", "ul": + return parser.listToString(node, ctx) + case "h1", "h2", "h3", "h4", "h5", "h6": + return parser.headerToString(node, ctx) + case "br": + return NewEntityString("\n") + case "b", "strong", "i", "em", "s", "strike", "del", "u", "ins", "tt", "code": + return parser.basicFormatToString(node, ctx) + case "span", "font": + return parser.spanToString(node, ctx) + case "a": + return parser.linkToString(node, ctx) + case "p": + return parser.nodeToTagAwareString(node.FirstChild, ctx) + case "hr": + return NewEntityString("---") + case "pre": + var preStr *EntityString + //var language string + if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { + //class := parser.getAttribute(node.FirstChild, "class") + //if strings.HasPrefix(class, "language-") { + // language = class[len("language-"):] + //} + preStr = parser.nodeToString(node.FirstChild.FirstChild, ctx.WithWhitespace()) + } else { + preStr = parser.nodeToString(node.FirstChild, ctx.WithWhitespace()) + } + return preStr.Format(signalfmt.StyleMonospace) + default: + return parser.nodeToTagAwareString(node.FirstChild, ctx) + } +} + +func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString { + switch node.Type { + case html.TextNode: + if !ctx.PreserveWhitespace { + node.Data = strings.Replace(node.Data, "\n", "", -1) + } + return TaggedString{NewEntityString(node.Data), "text"} + case html.ElementNode: + return TaggedString{parser.tagToString(node, ctx), node.Data} + case html.DocumentNode: + return TaggedString{parser.nodeToTagAwareString(node.FirstChild, ctx), "html"} + default: + return TaggedString{&EntityString{}, "unknown"} + } +} + +func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, ctx Context) (strs []TaggedString) { + for ; node != nil; node = node.NextSibling { + strs = append(strs, parser.singleNodeToString(node, ctx)) + } + return +} + +var BlockTags = []string{"p", "h1", "h2", "h3", "h4", "h5", "h6", "ol", "ul", "pre", "blockquote", "div", "hr", "table"} + +func (parser *HTMLParser) isBlockTag(tag string) bool { + for _, blockTag := range BlockTags { + if tag == blockTag { + return true + } + } + return false +} + +func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, ctx Context) *EntityString { + strs := parser.nodeToTaggedStrings(node, ctx) + var output *EntityString + for _, str := range strs { + tstr := str.EntityString + if parser.isBlockTag(str.tag) { + tstr = NewEntityString("\n").Append(tstr).AppendString("\n") + } + if output == nil { + output = tstr + } else { + output = output.Append(tstr) + } + } + return output.TrimSpace() +} + +func (parser *HTMLParser) nodeToStrings(node *html.Node, ctx Context) (strs []*EntityString) { + for ; node != nil; node = node.NextSibling { + strs = append(strs, parser.singleNodeToString(node, ctx).EntityString) + } + return +} + +func (parser *HTMLParser) nodeToString(node *html.Node, ctx Context) *EntityString { + return JoinEntityString("", parser.nodeToStrings(node, ctx)...) +} + +// Parse converts Matrix HTML into text using the settings in this parser. +func (parser *HTMLParser) Parse(htmlData string, ctx Context) *EntityString { + //htmlData = strings.Replace(htmlData, "\t", " ", -1) + node, _ := html.Parse(strings.NewReader(htmlData)) + return parser.nodeToTagAwareString(node, ctx) +} diff --git a/msgconv/signalfmt/convert.go b/msgconv/signalfmt/convert.go index 1abfb2c3..fc1162fe 100644 --- a/msgconv/signalfmt/convert.go +++ b/msgconv/signalfmt/convert.go @@ -96,7 +96,7 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) * // Maybe use NewUTF16String and do index replacements for the plaintext body too, // or just replace the plaintext body by parsing the generated HTML. content.Body = strings.Replace(content.Body, "\uFFFC", userInfo.Name, 1) - br.Value = Mention(userInfo) + br.Value = Mention{UserInfo: userInfo, UUID: rv.MentionUuid} } lrt.Add(br) } diff --git a/msgconv/signalfmt/tags.go b/msgconv/signalfmt/tags.go index 07756a50..74132a1a 100644 --- a/msgconv/signalfmt/tags.go +++ b/msgconv/signalfmt/tags.go @@ -25,14 +25,24 @@ import ( type BodyRangeValue interface { String() string Format(message string) string + Proto() signalpb.BodyRangeAssociatedValue } -type Mention UserInfo +type Mention struct { + UserInfo + UUID string +} func (m Mention) String() string { return fmt.Sprintf("Mention{MXID: id.UserID(%q), Name: %q}", m.MXID, m.Name) } +func (m Mention) Proto() signalpb.BodyRangeAssociatedValue { + return &signalpb.BodyRange_MentionUuid{ + MentionUuid: m.UUID, + } +} + type Style int const ( @@ -44,8 +54,10 @@ const ( StyleMonospace ) -func (s Style) Proto() signalpb.BodyRange_Style { - return signalpb.BodyRange_Style(s) +func (s Style) Proto() signalpb.BodyRangeAssociatedValue { + return &signalpb.BodyRange_Style_{ + Style: signalpb.BodyRange_Style(s), + } } func (s Style) String() string { diff --git a/msgconv/signalfmt/tree.go b/msgconv/signalfmt/tree.go index e4ec0291..5f37b1c0 100644 --- a/msgconv/signalfmt/tree.go +++ b/msgconv/signalfmt/tree.go @@ -16,12 +16,41 @@ package signalfmt +import ( + "fmt" + "sort" + + "google.golang.org/protobuf/proto" + + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" +) + type BodyRange struct { Start int Length int Value BodyRangeValue } +type BodyRangeList []BodyRange + +var _ sort.Interface = BodyRangeList(nil) + +func (b BodyRangeList) Len() int { + return len(b) +} + +func (b BodyRangeList) Less(i, j int) bool { + return b[i].Start < b[j].Start || b[i].Length > b[j].Length +} + +func (b BodyRangeList) Swap(i, j int) { + b[i], b[j] = b[j], b[i] +} + +func (b BodyRange) String() string { + return fmt.Sprintf("%d:%d:%v", b.Start, b.Length, b.Value) +} + // End returns the end index of the range. func (b BodyRange) End() int { return b.Start + b.Length @@ -35,8 +64,10 @@ func (b BodyRange) Offset(offset int) *BodyRange { // TruncateStart changes the length of the range, so it starts at the given index and ends at the same index as before. func (b BodyRange) TruncateStart(startAt int) *BodyRange { - b.Length -= startAt - b.Start - b.Start = startAt + if b.Start < startAt { + b.Length -= startAt - b.Start + b.Start = startAt + } return &b } @@ -48,6 +79,14 @@ func (b BodyRange) TruncateEnd(maxEnd int) *BodyRange { return &b } +func (b BodyRange) Proto() *signalpb.BodyRange { + return &signalpb.BodyRange{ + Start: proto.Uint32(uint32(b.Start)), + Length: proto.Uint32(uint32(b.Length)), + AssociatedValue: b.Value.Proto(), + } +} + // LinkedRangeTree is a linked tree of formatting entities. // // It's meant to parse a list of Signal body ranges into nodes that either overlap completely or not at all, diff --git a/pkg/signalmeow/protobuf/extra.go b/pkg/signalmeow/protobuf/extra.go new file mode 100644 index 00000000..4475440a --- /dev/null +++ b/pkg/signalmeow/protobuf/extra.go @@ -0,0 +1,3 @@ +package signalpb + +type BodyRangeAssociatedValue = isBodyRange_AssociatedValue diff --git a/pkg/signalmeow/sending.go b/pkg/signalmeow/sending.go index c24b8427..6bc62543 100644 --- a/pkg/signalmeow/sending.go +++ b/pkg/signalmeow/sending.go @@ -411,16 +411,17 @@ func ReadReceptMessageForTimestamps(timestamps []uint64) *SignalContent { } } -func DataMessageForText(text string) *SignalContent { +func DataMessageForText(text string, ranges []*signalpb.BodyRange) *SignalContent { timestamp := currentMessageTimestamp() dm := &signalpb.DataMessage{ - Body: proto.String(text), - Timestamp: ×tamp, + Body: proto.String(text), + BodyRanges: ranges, + Timestamp: ×tamp, } return wrapDataMessageInContent(dm) } -func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption string) *SignalContent { +func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption string, ranges []*signalpb.BodyRange) *SignalContent { ap := (*signalpb.AttachmentPointer)(attachmentPointer) // Cast back to signalpb, this is okay AttachmentPointer is an alias timestamp := currentMessageTimestamp() dm := &signalpb.DataMessage{ diff --git a/portal.go b/portal.go index d505fb8a..3bbb792d 100644 --- a/portal.go +++ b/portal.go @@ -47,12 +47,13 @@ import ( "maunium.net/go/mautrix/bridge/status" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/event" - "maunium.net/go/mautrix/format" "maunium.net/go/mautrix/id" "go.mau.fi/mautrix-signal/database" + "go.mau.fi/mautrix-signal/msgconv/matrixfmt" "go.mau.fi/mautrix-signal/msgconv/signalfmt" "go.mau.fi/mautrix-signal/pkg/signalmeow" + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" ) type portalSignalMessage struct { @@ -649,34 +650,27 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev switch content.MsgType { case event.MsgText, event.MsgEmote, event.MsgNotice: - var text string - var mentions []string - if content.Format == event.FormatHTML { - text, mentions = portal.parseMentionsFromMatrixBody(content) - } else { - text = content.Body - mentions = nil - } if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices { return nil, errMNoticeDisabled } if content.MsgType == event.MsgEmote { - text = "/me " + text + content.Body = "/me " + content.Body + if content.FormattedBody != "" { + content.FormattedBody = "/me " + content.FormattedBody + } } + outgoingMessage = signalmeow.DataMessageForText(matrixfmt.Parse(matrixFormatParams, content)) if ctx.Err() != nil { return nil, ctx.Err() } - outgoingMessage = signalmeow.DataMessageForText(text) - if mentions != nil && len(mentions) > 0 { - signalmeow.AddMentionsToDataMessage(outgoingMessage, mentions) - } case event.MsgImage: fileName := content.Body var caption string + var ranges []*signalpb.BodyRange if content.FileName != "" && content.Body != content.FileName { fileName = content.FileName - caption = content.Body + caption, ranges = matrixfmt.Parse(matrixFormatParams, content) } image, err := portal.downloadAndDecryptMatrixMedia(ctx, content) if err != nil { @@ -690,15 +684,9 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if err != nil { return nil, err } - outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption) + outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges) case event.MessageType(event.EventSticker.Type): - fileName := content.Body - var caption string - if content.FileName != "" && content.Body != content.FileName { - fileName = content.FileName - caption = content.Body - } image, err := portal.downloadAndDecryptMatrixMedia(ctx, content) if err != nil { return nil, err @@ -707,17 +695,18 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if err != nil { return nil, err } - attachmentPointer, err := signalmeow.UploadAttachment(sender.SignalDevice, convertedSticker, newMimeType, fileName) + attachmentPointer, err := signalmeow.UploadAttachment(sender.SignalDevice, convertedSticker, newMimeType, content.FileName) if err != nil { return nil, err } - outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption) + outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, "", nil) case event.MsgVideo: fileName := content.Body var caption string + var ranges []*signalpb.BodyRange if content.FileName != "" && content.Body != content.FileName { fileName = content.FileName - caption = content.Body + caption, ranges = matrixfmt.Parse(matrixFormatParams, content) } image, err := portal.downloadAndDecryptMatrixMedia(ctx, content) if err != nil { @@ -731,14 +720,15 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if err != nil { return nil, err } - outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption) + outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges) case event.MsgAudio: fileName := content.Body var caption string + var ranges []*signalpb.BodyRange if content.FileName != "" && content.Body != content.FileName { fileName = content.FileName - caption = content.Body + caption, ranges = matrixfmt.Parse(matrixFormatParams, content) } image, err := portal.downloadAndDecryptMatrixMedia(ctx, content) if err != nil { @@ -752,14 +742,15 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if err != nil { return nil, err } - outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption) + outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges) case event.MsgFile: fileName := content.Body var caption string + var ranges []*signalpb.BodyRange if content.FileName != "" && content.Body != content.FileName { fileName = content.FileName - caption = content.Body + caption, ranges = matrixfmt.Parse(matrixFormatParams, content) } file, err := portal.downloadAndDecryptMatrixMedia(ctx, content) if err != nil { @@ -769,7 +760,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev if err != nil { return nil, err } - outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption) + outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges) case event.MsgLocation: fallthrough @@ -1012,79 +1003,13 @@ func (portal *Portal) addDisappearingMessage(eventID id.EventID, expireInSeconds portal.bridge.disappearingMessagesManager.AddDisappearingMessage(eventID, portal.MXID, expireInSeconds, startTimerNow) } -const mentionedSignalIDsContextKey = "fi.mau.signal.mentioned_ids" - -func (portal *Portal) parseMentionsFromMatrixBody(content *event.MessageEventContent) (string, []string) { - var allowedMentions map[string]bool = nil - var mentionedSignalIDs []string - - // If the matrix event has explicit mentions, we only want to allow parsing mentions from those - if content.Mentions != nil { - allowedMentions = make(map[string]bool, len(content.Mentions.UserIDs)) - mentionedSignalIDs = make([]string, 0, len(content.Mentions.UserIDs)) - for _, userID := range content.Mentions.UserIDs { - var signalID string - if puppet := portal.bridge.GetPuppetByMXID(userID); puppet != nil { - signalID = puppet.SignalID - mentionedSignalIDs = append(mentionedSignalIDs, puppet.SignalID) - } - if signalID != "" && !allowedMentions[signalID] { - allowedMentions[signalID] = true - mentionedSignalIDs = append(mentionedSignalIDs, signalID) - } - } - } - - // Parse what mentions we can find out of the HTML, and replace with unicode replacement character - matrixHTMLParser := &format.HTMLParser{ - TabsToSpaces: 4, - Newline: "\n", - - PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string { - if mxid[0] == '@' { - var signalID string - if puppet := portal.bridge.GetPuppetByMXID(id.UserID(mxid)); puppet != nil { - signalID = puppet.SignalID - } - if signalID != "" && (allowedMentions == nil || allowedMentions[signalID]) { - if allowedMentions == nil { - ids, ok := ctx.ReturnData[mentionedSignalIDsContextKey].([]string) - if !ok { - ctx.ReturnData[mentionedSignalIDsContextKey] = []string{signalID} - } else { - ctx.ReturnData[mentionedSignalIDsContextKey] = append(ids, signalID) - } - } - // Signal needs the Unicode replacement character, then it will add the name itself - return "\uFFFC" - } - } - return displayname - }, - BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) }, - ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) }, - StrikethroughConverter: func(text string, _ format.Context) string { return fmt.Sprintf("~%s~", text) }, - MonospaceConverter: func(text string, _ format.Context) string { return fmt.Sprintf("```%s```", text) }, - MonospaceBlockConverter: func(text, language string, _ format.Context) string { return fmt.Sprintf("```%s```", text) }, - } - - formatContext := format.NewContext() - parsedBody := matrixHTMLParser.Parse(content.FormattedBody, formatContext) - - // If we didn't have any explicit mentions, we can use the ones we parsed from the HTML - if content.Mentions == nil { - mentionedSignalIDs, _ = formatContext.ReturnData[mentionedSignalIDsContextKey].([]string) - } - - return parsedBody, mentionedSignalIDs -} - -var formatParams *signalfmt.FormatParams +var signalFormatParams *signalfmt.FormatParams +var matrixFormatParams *matrixfmt.HTMLParser func (portal *Portal) handleSignalTextMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { timestamp := portalMessage.message.Base().Timestamp msg := (portalMessage.message).(signalmeow.IncomingSignalMessageText) - content := signalfmt.Parse(msg.Content, msg.ContentRanges, formatParams) + content := signalfmt.Parse(msg.Content, msg.ContentRanges, signalFormatParams) portal.addSignalQuote(content, msg.Quote) resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0) if err != nil { @@ -1354,7 +1279,7 @@ func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.Eve func (portal *Portal) handleSignalAttachmentMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { timestamp := portalMessage.message.Base().Timestamp msg := (portalMessage.message).(signalmeow.IncomingSignalMessageAttachment) - content := signalfmt.Parse(msg.Caption, msg.CaptionRanges, formatParams) + content := signalfmt.Parse(msg.Caption, msg.CaptionRanges, signalFormatParams) content.Info = &event.FileInfo{ MimeType: msg.ContentType, Size: int(msg.Size),