diff --git a/ROADMAP.md b/ROADMAP.md
index 6dd0d909..432caa71 100644
--- a/ROADMAP.md
+++ b/ROADMAP.md
@@ -3,7 +3,7 @@
* Matrix → Signal
* [ ] Message content
* [x] Text
- * [ ] Formatting
+ * [x] Formatting
* [x] Mentions
* [ ] Media
* [x] Images
diff --git a/main.go b/main.go
index c80ff2d9..ad38c185 100644
--- a/main.go
+++ b/main.go
@@ -35,6 +35,7 @@ import (
"go.mau.fi/mautrix-signal/config"
"go.mau.fi/mautrix-signal/database"
+ "go.mau.fi/mautrix-signal/msgconv/matrixfmt"
"go.mau.fi/mautrix-signal/msgconv/signalfmt"
"go.mau.fi/mautrix-signal/pkg/signalmeow"
)
@@ -116,7 +117,7 @@ func (br *SignalBridge) Init() {
br.Metrics = NewMetricsHandler(br.Config.Metrics.Listen, br.Log.Sub("Metrics"), br.DB)
br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent
- formatParams = &signalfmt.FormatParams{
+ signalFormatParams = &signalfmt.FormatParams{
GetUserInfo: func(uuid string) signalfmt.UserInfo {
puppet := br.GetPuppetBySignalID(uuid)
if puppet == nil {
@@ -135,6 +136,20 @@ func (br *SignalBridge) Init() {
}
},
}
+ matrixFormatParams = &matrixfmt.HTMLParser{
+ GetUUIDFromMXID: func(userID id.UserID) string {
+ parsed, ok := br.ParsePuppetMXID(userID)
+ if ok {
+ return parsed
+ }
+ // TODO only get if exists
+ user := br.GetUserByMXID(userID)
+ if user != nil && user.SignalID != "" {
+ return user.SignalID
+ }
+ return ""
+ },
+ }
signalmeow.HackyCaptionToggle = br.Config.Bridge.CaptionInMessage
}
diff --git a/msgconv/matrixfmt/convert.go b/msgconv/matrixfmt/convert.go
new file mode 100644
index 00000000..8e20e521
--- /dev/null
+++ b/msgconv/matrixfmt/convert.go
@@ -0,0 +1,43 @@
+// mautrix-signal - A Matrix-Signal puppeting bridge.
+// Copyright (C) 2023 Tulir Asokan
+//
+// This program is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Affero General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// This program is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Affero General Public License for more details.
+//
+// You should have received a copy of the GNU Affero General Public License
+// along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+package matrixfmt
+
+import (
+ "maunium.net/go/mautrix/event"
+
+ signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
+)
+
+// Parse converts the content of a Matrix message event into a Signal message
+// body and the body ranges describing its formatting and mentions.
+//
+// Content that is not marked as HTML is passed through untouched: the plain
+// body is returned with no ranges. For HTML content the formatted body is run
+// through parser, with the event's mentions used as the allowed mentions of
+// the parsing context.
+func Parse(parser *HTMLParser, content *event.MessageEventContent) (string, []*signalpb.BodyRange) {
+	if content.Format != event.FormatHTML {
+		return content.Body, nil
+	}
+
+	htmlCtx := NewContext()
+	htmlCtx.AllowedMentions = content.Mentions
+	result := parser.Parse(content.FormattedBody, htmlCtx)
+	if result == nil {
+		// The parser produced no output at all.
+		return "", nil
+	}
+	if len(result.Entities) == 0 {
+		return result.String.String(), nil
+	}
+
+	ranges := make([]*signalpb.BodyRange, 0, len(result.Entities))
+	for _, entity := range result.Entities {
+		ranges = append(ranges, entity.Proto())
+	}
+	return result.String.String(), ranges
+}
diff --git a/msgconv/matrixfmt/convert_test.go b/msgconv/matrixfmt/convert_test.go
new file mode 100644
index 00000000..719e4be0
--- /dev/null
+++ b/msgconv/matrixfmt/convert_test.go
@@ -0,0 +1,159 @@
+package matrixfmt_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+
+ "go.mau.fi/mautrix-signal/msgconv/matrixfmt"
+ "go.mau.fi/mautrix-signal/msgconv/signalfmt"
+)
+
+// formatParams is the parser shared by the tests in this file. It treats the
+// localpart of any user on the "signal" homeserver as that user's Signal UUID
+// and reports every other user as unknown.
+var formatParams = &matrixfmt.HTMLParser{
+	GetUUIDFromMXID: func(userID id.UserID) string {
+		if userID.Homeserver() != "signal" {
+			return ""
+		}
+		return userID.Localpart()
+	},
+}
+
+// TestParse_Empty checks that an empty plaintext message converts to an empty
+// body without any ranges.
+func TestParse_Empty(t *testing.T) {
+	content := &event.MessageEventContent{
+		MsgType: event.MsgText,
+		Body:    "",
+	}
+	body, ranges := matrixfmt.Parse(formatParams, content)
+	assert.Equal(t, "", body)
+	assert.Empty(t, ranges)
+}
+
+// TestParse_EmptyHTML checks that a message flagged as HTML but carrying an
+// empty formatted body converts to an empty body without any ranges.
+func TestParse_EmptyHTML(t *testing.T) {
+	content := &event.MessageEventContent{
+		MsgType:       event.MsgText,
+		Body:          "",
+		Format:        event.FormatHTML,
+		FormattedBody: "",
+	}
+	body, ranges := matrixfmt.Parse(formatParams, content)
+	assert.Equal(t, "", body)
+	assert.Empty(t, ranges)
+}
+
+// TestParse_Plaintext checks that a message without an HTML format is passed
+// through verbatim and yields no ranges.
+func TestParse_Plaintext(t *testing.T) {
+	content := &event.MessageEventContent{
+		MsgType: event.MsgText,
+		Body:    "Hello world!",
+	}
+	body, ranges := matrixfmt.Parse(formatParams, content)
+	assert.Equal(t, "Hello world!", body)
+	assert.Empty(t, ranges)
+}
+
+// TestParse_HTML runs a table of HTML snippets through the parser and checks
+// both the resulting plain text and the body ranges attached to it.
+//
+// Entity offsets are UTF-16 code unit indexes into the expected output, so the
+// exact amount of whitespace in each "out" string matters.
+func TestParse_HTML(t *testing.T) {
+	tests := []struct {
+		name string
+		in   string
+		out  string
+		ent  signalfmt.BodyRangeList
+	}{
+		{name: "Plain", in: "Hello, World!", out: "Hello, World!"},
+		{name: "Basic", in: "<strong>Hello</strong>, World!", out: "Hello, World!", ent: signalfmt.BodyRangeList{{
+			Start:  0,
+			Length: 5,
+			Value:  signalfmt.StyleBold,
+		}}},
+		{
+			name: "MultiBasic",
+			in:   "<strong><em>Hell</em>o</strong>, <del>Wo<span data-mx-spoiler>rld</span></del><code>!</code>",
+			out:  "Hello, World!",
+			ent: signalfmt.BodyRangeList{{
+				Start:  0,
+				Length: 5,
+				Value:  signalfmt.StyleBold,
+			}, {
+				Start:  0,
+				Length: 4,
+				Value:  signalfmt.StyleItalic,
+			}, {
+				Start:  7,
+				Length: 5,
+				Value:  signalfmt.StyleStrikethrough,
+			}, {
+				Start:  9,
+				Length: 3,
+				Value:  signalfmt.StyleSpoiler,
+			}, {
+				Start:  12,
+				Length: 1,
+				Value:  signalfmt.StyleMonospace,
+			}},
+		},
+		{
+			name: "TrimSpace",
+			in:   "<strong> Hello </strong>",
+			out:  "Hello",
+			ent: signalfmt.BodyRangeList{{
+				Start:  0,
+				Length: 5,
+				Value:  signalfmt.StyleBold,
+			}},
+		},
+		{
+			name: "List",
+			in:   "<ul><li>woof</li><li><strong>meow</strong></li><li><pre><code>hmm\nmeow</code></pre></li><li><blockquote><p>meow</p><h1>meow</h1></blockquote></li></ul>",
+			out:  "* woof\n* meow\n* hmm\n  meow\n* > meow\n  > \n  > # meow",
+			ent: signalfmt.BodyRangeList{{
+				Start:  9,
+				Length: 4,
+				Value:  signalfmt.StyleBold,
+			}, {
+				Start:  16,
+				Length: 3,
+				Value:  signalfmt.StyleMonospace,
+			}, {
+				// FIXME optimally this would be a single range with the previous one so the indent is also monospace
+				Start:  22,
+				Length: 4,
+				Value:  signalfmt.StyleMonospace,
+			}, {
+				Start:  45,
+				Length: 6,
+				Value:  signalfmt.StyleBold,
+			}},
+		},
+		{
+			name: "OrderedList",
+			in:   "<ol start=\"9\"><li>woof</li><li><strong>meow</strong></li><li><pre><code>hmm\nmeow</code></pre></li><li><blockquote><p>meow</p><h1>meow</h1></blockquote></li></ol>",
+			out:  "9.  woof\n10. meow\n11. hmm\n    meow\n12. > meow\n    > \n    > # meow",
+			ent: signalfmt.BodyRangeList{{
+				Start:  13,
+				Length: 4,
+				Value:  signalfmt.StyleBold,
+			}, {
+				Start:  22,
+				Length: 3,
+				Value:  signalfmt.StyleMonospace,
+			}, {
+				Start:  30,
+				Length: 4,
+				Value:  signalfmt.StyleMonospace,
+			}, {
+				Start:  59,
+				Length: 6,
+				Value:  signalfmt.StyleBold,
+			}},
+		},
+	}
+	// Print the parser's debug trace so a failing case can be followed step by step.
+	matrixfmt.DebugLog = func(format string, args ...any) {
+		fmt.Printf(format, args...)
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			fmt.Println("--------------------------------------------------------------------------------")
+			parsed := formatParams.Parse(test.in, matrixfmt.NewContext())
+			assert.Equal(t, test.out, parsed.String.String())
+			assert.Equal(t, test.ent, parsed.Entities)
+		})
+	}
+}
diff --git a/msgconv/matrixfmt/html.go b/msgconv/matrixfmt/html.go
new file mode 100644
index 00000000..28daec6c
--- /dev/null
+++ b/msgconv/matrixfmt/html.go
@@ -0,0 +1,483 @@
+package matrixfmt
+
+import (
+ "fmt"
+ "math"
+ "strconv"
+ "strings"
+
+ "golang.org/x/exp/slices"
+ "golang.org/x/net/html"
+ "maunium.net/go/mautrix/event"
+ "maunium.net/go/mautrix/id"
+
+ "go.mau.fi/mautrix-signal/msgconv/signalfmt"
+)
+
+// EntityString is a UTF-16 string together with the body ranges that apply to it.
+// The Start and Length of every entity are indexes into String.
+type EntityString struct {
+ String signalfmt.UTF16String
+ Entities signalfmt.BodyRangeList
+}
+
+// DebugLog receives a trace of the EntityString operations performed while parsing.
+// It is a no-op by default and can be replaced (e.g. in tests) to print the trace.
+var DebugLog = func(format string, args ...any) {}
+
+// NewEntityString wraps a plain Go string into an EntityString that has no entities.
+func NewEntityString(val string) *EntityString {
+	DebugLog("NEW %q\n", val)
+	es := new(EntityString)
+	es.String = signalfmt.NewUTF16String(val)
+	return es
+}
+
+// Split cuts the string at every occurrence of the given ASCII character and
+// returns the pieces without the separator. Entities are clipped to the piece
+// they fall into and re-based so their offsets are relative to that piece; an
+// entity spanning a separator is present in every piece it overlaps.
+//
+// A nil receiver yields an empty slice. If the separator does not occur, the
+// receiver itself is returned as the only element. A separator at the very end
+// of the string does not produce a trailing empty piece.
+//
+// Split panics if at is not an ASCII character.
+func (es *EntityString) Split(at uint16) []*EntityString {
+	if at > 0x7F {
+		panic("cannot split at non-ASCII character")
+	}
+	if es == nil {
+		return []*EntityString{}
+	}
+	DebugLog("SPLIT %q %q %+v\n", es.String, rune(at), es.Entities)
+	var output []*EntityString
+	prevSplit := 0
+	// doSplit builds the piece covering [prevSplit, i) of the original string.
+	doSplit := func(i int) *EntityString {
+		newES := &EntityString{
+			String: es.String[prevSplit:i],
+		}
+		for _, entity := range es.Entities {
+			// Skip entities that lie entirely before or entirely after this piece.
+			if entity.End() <= prevSplit || entity.Start >= i {
+				continue
+			}
+			entity = *entity.TruncateStart(prevSplit).TruncateEnd(i).Offset(-prevSplit)
+			if entity.Length > 0 {
+				newES.Entities = append(newES.Entities, entity)
+			}
+		}
+		return newES
+	}
+	for i, chr := range es.String {
+		if chr != at {
+			continue
+		}
+		newES := doSplit(i)
+		output = append(output, newES)
+		DebugLog(" -> %q %+v\n", newES.String, newES.Entities)
+		prevSplit = i + 1
+	}
+	if prevSplit == 0 {
+		DebugLog(" -> NOOP\n")
+		return []*EntityString{es}
+	}
+	if prevSplit != len(es.String) {
+		// Emit whatever follows the last separator.
+		newES := doSplit(len(es.String))
+		output = append(output, newES)
+		DebugLog(" -> %q %+v\n", newES.String, newES.Entities)
+	}
+	DebugLog("SPLITEND\n")
+	return output
+}
+
+// TrimSpace removes leading and trailing whitespace from the string in place
+// and returns the receiver. Entities are clipped to the remaining text and
+// shifted to match; entities that end up empty are dropped.
+//
+// A string consisting only of whitespace becomes an empty string.
+// A nil receiver is returned as nil.
+func (es *EntityString) TrimSpace() *EntityString {
+	if es == nil {
+		return nil
+	}
+	DebugLog("TRIMSPACE %q %+v\n", es.String, es.Entities)
+	isSpace := func(chr uint16) bool {
+		switch chr {
+		case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0:
+			return true
+		}
+		return false
+	}
+	cutStart := 0
+	for cutStart < len(es.String) && isSpace(es.String[cutStart]) {
+		cutStart++
+	}
+	// Never scan past cutStart, so an all-whitespace string ends up with
+	// cutStart == cutEnd instead of an inverted (and panicking) slice range.
+	cutEnd := len(es.String)
+	for cutEnd > cutStart && isSpace(es.String[cutEnd-1]) {
+		cutEnd--
+	}
+	if cutStart == 0 && cutEnd == len(es.String) {
+		DebugLog(" -> NOOP\n")
+		return es
+	}
+	// Filter in place; at most one element is written per element read.
+	newEntities := es.Entities[:0]
+	for _, ent := range es.Entities {
+		// Clip against the kept window in the original coordinates first,
+		// then shift into the coordinates of the trimmed string.
+		ent = *ent.TruncateStart(cutStart).TruncateEnd(cutEnd).Offset(-cutStart)
+		if ent.Length > 0 {
+			newEntities = append(newEntities, ent)
+		}
+	}
+	es.String = es.String[cutStart:cutEnd]
+	es.Entities = newEntities
+	DebugLog(" -> %q %+v\n", es.String, es.Entities)
+	return es
+}
+
+// JoinEntityString concatenates the given strings into a new EntityString,
+// shifting the entities of each input to their position in the result.
+//
+// Nil and empty inputs are skipped entirely. The separator is written after
+// every input that is kept, including the last one, so a non-empty result
+// ends with the separator.
+func JoinEntityString(with string, strings ...*EntityString) *EntityString {
+	withUTF16 := signalfmt.NewUTF16String(with)
+	totalLen := 0
+	totalEntities := 0
+	for _, s := range strings {
+		if s == nil {
+			// Nil inputs are allowed and skipped below, so they must not be dereferenced here either.
+			continue
+		}
+		totalLen += len(s.String)
+		totalEntities += len(s.Entities)
+	}
+	str := make(signalfmt.UTF16String, 0, totalLen+len(strings)*len(withUTF16))
+	entities := make(signalfmt.BodyRangeList, 0, totalEntities)
+	DebugLog("JOIN %q %d\n", with, len(strings))
+	for _, s := range strings {
+		if s == nil || len(s.String) == 0 {
+			continue
+		}
+		DebugLog(" + %q %+v\n", s.String, s.Entities)
+		for _, entity := range s.Entities {
+			entity.Start += len(str)
+			entities = append(entities, entity)
+		}
+		str = append(str, s.String...)
+		str = append(str, withUTF16...)
+	}
+	DebugLog(" -> %q %+v\n", str, entities)
+	return &EntityString{
+		String:   str,
+		Entities: entities,
+	}
+}
+
+// Format marks the whole string with the given value by putting a new entity
+// that covers the full length in front of the existing entities.
+// It modifies and returns the receiver; a nil receiver stays nil.
+func (es *EntityString) Format(value signalfmt.BodyRangeValue) *EntityString {
+	if es == nil {
+		return nil
+	}
+	entities := make(signalfmt.BodyRangeList, 0, len(es.Entities)+1)
+	entities = append(entities, signalfmt.BodyRange{
+		Start:  0,
+		Length: len(es.String),
+		Value:  value,
+	})
+	entities = append(entities, es.Entities...)
+	es.Entities = entities
+	DebugLog("FORMAT %v %q %+v\n", value, es.String, es.Entities)
+	return es
+}
+
+// Append adds other to the end of the receiver, moving the entities of other
+// behind the existing text. The receiver is modified and returned.
+// If either side is nil, the other side is returned unchanged.
+func (es *EntityString) Append(other *EntityString) *EntityString {
+	switch {
+	case es == nil:
+		return other
+	case other == nil:
+		return es
+	}
+	DebugLog("APPEND %q %+v\n + %q %+v\n", es.String, es.Entities, other.String, other.Entities)
+	// The text only grows after this loop, so the shift is the same for every entity.
+	shift := len(es.String)
+	for _, entity := range other.Entities {
+		entity.Start += shift
+		es.Entities = append(es.Entities, entity)
+	}
+	es.String = append(es.String, other.String...)
+	DebugLog(" -> %q %+v\n", es.String, es.Entities)
+	return es
+}
+
+// AppendString adds plain text without any entities to the end of the receiver.
+// A nil receiver results in a fresh EntityString holding just the text.
+func (es *EntityString) AppendString(other string) *EntityString {
+	if es == nil {
+		return NewEntityString(other)
+	}
+	if other == "" {
+		return es
+	}
+	DebugLog("APPENDSTRING %q %+v\n + %q\n", es.String, es.Entities, other)
+	suffix := signalfmt.NewUTF16String(other)
+	es.String = append(es.String, suffix...)
+	DebugLog(" -> %q %+v\n", es.String, es.Entities)
+	return es
+}
+
+// TagStack is a list of tag names ordered from the outermost to the innermost tag.
+type TagStack []string
+
+// Index returns the position of the last occurrence of tag in the stack,
+// or -1 if the tag is not in the stack.
+func (ts TagStack) Index(tag string) int {
+	pos := len(ts)
+	for pos > 0 {
+		pos--
+		if ts[pos] == tag {
+			return pos
+		}
+	}
+	return -1
+}
+
+// Has reports whether tag occurs anywhere in the stack.
+func (ts TagStack) Has(tag string) bool {
+	return ts.Index(tag) != -1
+}
+
+// Context carries the state that is passed down the HTML tree while parsing.
+type Context struct {
+ // AllowedMentions, when non-nil, limits which user links are turned into mentions.
+ AllowedMentions *event.Mentions
+ // TagStack holds the tags added with WithTag.
+ TagStack TagStack
+ // PreserveWhitespace keeps newlines inside text nodes instead of removing them.
+ PreserveWhitespace bool
+}
+
+// NewContext returns a context with an empty, preallocated tag stack.
+func NewContext() Context {
+ return Context{
+ TagStack: make(TagStack, 0, 4),
+ }
+}
+
+// WithTag returns a copy of the context with tag added to the end of the tag stack.
+// NOTE(review): while spare capacity remains, the returned stack shares its backing array
+// with the receiver's stack, so two contexts derived from the same parent can overwrite
+// each other's last element.
+func (ctx Context) WithTag(tag string) Context {
+ ctx.TagStack = append(ctx.TagStack, tag)
+ return ctx
+}
+
+// WithWhitespace returns a copy of the context with PreserveWhitespace enabled.
+func (ctx Context) WithWhitespace() Context {
+ ctx.PreserveWhitespace = true
+ return ctx
+}
+
+// HTMLParser is a somewhat customizable Matrix HTML parser.
+type HTMLParser struct {
+ // GetUUIDFromMXID resolves a Matrix user ID to a Signal UUID.
+ // An empty return value means the user link is not converted into a mention.
+ GetUUIDFromMXID func(id.UserID) string
+}
+
+// TaggedString is a string that also contains a HTML tag.
+// For nodes that are not elements, the tag is a placeholder ("text", "html" or "unknown").
+type TaggedString struct {
+ *EntityString
+ tag string
+}
+
+// maybeGetAttribute looks up an attribute of the node by key. The second return
+// value tells whether the attribute is present, which makes it possible to
+// distinguish an attribute with an empty value from a missing one.
+func (parser *HTMLParser) maybeGetAttribute(node *html.Node, attribute string) (string, bool) {
+	for i := range node.Attr {
+		if node.Attr[i].Key != attribute {
+			continue
+		}
+		return node.Attr[i].Val, true
+	}
+	return "", false
+}
+
+// getAttribute returns the value of the given attribute, or an empty string
+// if the node does not have that attribute.
+func (parser *HTMLParser) getAttribute(node *html.Node, attribute string) string {
+	if val, found := parser.maybeGetAttribute(node, attribute); found {
+		return val
+	}
+	return ""
+}
+
+// Digits counts the number of digits (and the sign, if negative) in an integer.
+//
+// The count is done with integer arithmetic, so it is exact for every value
+// including exact powers of ten.
+func Digits(num int) int {
+	if num < 0 {
+		if num == math.MinInt {
+			// -num would overflow back to num, so strip one digit before negating
+			// and account for it (plus the sign) separately.
+			return Digits(-(num / 10)) + 2
+		}
+		return Digits(-num) + 1
+	}
+	count := 1
+	for num >= 10 {
+		num /= 10
+		count++
+	}
+	return count
+}
+
+// listToString renders an ordered or unordered list as plain text.
+//
+// Every <li> child becomes one item prefixed with "* " (unordered) or its
+// number (ordered). Numbers start at the "start" attribute if it holds a valid
+// integer and at 1 otherwise, and are padded so that all items line up.
+// Continuation lines of multi-line items are indented to the width of the
+// prefix. Children that are not <li> elements are ignored.
+func (parser *HTMLParser) listToString(node *html.Node, ctx Context) *EntityString {
+	ordered := node.Data == "ol"
+	taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, ctx)
+	counter := 1
+	indentLength := 0
+	if ordered {
+		if start := parser.getAttribute(node, "start"); len(start) > 0 {
+			// Keep the default when the attribute isn't a number, as Atoi returns 0 on failure.
+			if parsedStart, err := strconv.Atoi(start); err == nil {
+				counter = parsedStart
+			}
+		}
+
+		// Only <li> children get a number, so only they count towards the last index.
+		itemCount := 0
+		for _, child := range taggedChildren {
+			if child.tag == "li" {
+				itemCount++
+			}
+		}
+		longestIndex := (counter - 1) + itemCount
+		indentLength = Digits(longestIndex)
+		// With a negative start the first index can be wider than the last one.
+		if firstLength := Digits(counter); firstLength > indentLength {
+			indentLength = firstLength
+		}
+	}
+	indent := strings.Repeat(" ", indentLength+2)
+	var children []*EntityString
+	for _, child := range taggedChildren {
+		if child.tag != "li" {
+			continue
+		}
+		var prefix string
+		// TODO make bullets and numbering configurable
+		if ordered {
+			indexPadding := indentLength - Digits(counter)
+			if indexPadding < 0 {
+				// Defensive only: indentLength already covers the widest index.
+				indexPadding = 0
+			}
+			prefix = fmt.Sprintf("%d. %s", counter, strings.Repeat(" ", indexPadding))
+		} else {
+			prefix = "* "
+		}
+		es := NewEntityString(prefix).Append(child.EntityString)
+		counter++
+		parts := es.Split('\n')
+		// The first line carries the prefix, all following lines are indented to match it.
+		for i, part := range parts[1:] {
+			parts[i+1] = NewEntityString(indent).Append(part)
+		}
+		children = append(children, parts...)
+	}
+	return JoinEntityString("\n", children...)
+}
+
+// basicFormatToString converts the children of an inline formatting tag and
+// applies the matching style to the result. Underline tags ("u", "ins") and
+// any other tag are converted without a style.
+func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) *EntityString {
+	str := parser.nodeToTagAwareString(node.FirstChild, ctx)
+	var style signalfmt.Style
+	switch node.Data {
+	case "b", "strong":
+		style = signalfmt.StyleBold
+	case "i", "em":
+		style = signalfmt.StyleItalic
+	case "s", "del", "strike":
+		style = signalfmt.StyleStrikethrough
+	case "tt", "code":
+		style = signalfmt.StyleMonospace
+	default:
+		// Includes "u" and "ins", for which no style is applied.
+		return str
+	}
+	return str.Format(style)
+}
+
+// spanToString converts the children of the node. If the node is a <span>
+// carrying a data-mx-spoiler attribute, the result is marked as a spoiler.
+func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) *EntityString {
+	str := parser.nodeToTagAwareString(node.FirstChild, ctx)
+	if node.Data != "span" {
+		return str
+	}
+	if _, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler"); isSpoiler {
+		return str.Format(signalfmt.StyleSpoiler)
+	}
+	return str
+}
+
+// headerToString renders a heading as bold text prefixed with one "#" per
+// heading level. The level is read from the second character of the tag name.
+func (parser *HTMLParser) headerToString(node *html.Node, ctx Context) *EntityString {
+	level := int(node.Data[1] - '0')
+	heading := NewEntityString(strings.Repeat("#", level) + " ")
+	heading = heading.Append(parser.nodeToString(node.FirstChild, ctx))
+	return heading.Format(signalfmt.StyleBold)
+}
+
+// blockquoteToString renders a quote by prefixing every line of the converted
+// children with "> ".
+func (parser *HTMLParser) blockquoteToString(node *html.Node, ctx Context) *EntityString {
+	lines := parser.nodeToTagAwareString(node.FirstChild, ctx).TrimSpace().Split('\n')
+	for i := range lines {
+		lines[i] = NewEntityString("> ").Append(lines[i])
+	}
+	return JoinEntityString("\n", lines...)
+}
+
+// linkToString converts an <a> element.
+//
+// Links to Matrix users become a Signal mention (a U+FFFC placeholder with a
+// mention entity) if the mention is allowed by the context and the user maps
+// to a Signal UUID; otherwise only the link text is kept. For all other links
+// the target is appended in parentheses unless the text already is the target.
+// Elements without a href are converted like plain text, and a link without
+// any text is rendered as its bare target.
+func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityString {
+	str := parser.nodeToTagAwareString(node.FirstChild, ctx)
+	href := parser.getAttribute(node, "href")
+	if len(href) == 0 {
+		return str
+	}
+	// str is nil when the element has no content, so it must not be dereferenced blindly.
+	var text string
+	if str != nil {
+		text = str.String.String()
+	}
+	parsedMatrix, err := id.ParseMatrixURIOrMatrixToURL(href)
+	if err == nil && parsedMatrix != nil && parsedMatrix.Sigil1 == '@' {
+		mxid := parsedMatrix.UserID()
+		if ctx.AllowedMentions != nil && !slices.Contains(ctx.AllowedMentions.UserIDs, mxid) {
+			// Mention not allowed, use name as-is
+			return str
+		}
+		uuid := parser.GetUUIDFromMXID(mxid)
+		if uuid == "" {
+			// Don't include the link for mentions of non-Signal users, the name is enough
+			return str
+		}
+		return NewEntityString("\uFFFC").Format(signalfmt.Mention{
+			UserInfo: signalfmt.UserInfo{
+				MXID: mxid,
+				Name: text,
+			},
+			UUID: uuid,
+		})
+	}
+	if str == nil {
+		// Nothing to annotate, so the target itself is the best representation.
+		return NewEntityString(href)
+	}
+	if text == href {
+		return str
+	}
+	return str.AppendString(fmt.Sprintf(" (%s)", href))
+}
+
+// tagToString converts an element node by dispatching on its tag name.
+// The tag is added to the context's tag stack before the children are handled.
+// Tags without special handling are replaced by their converted children.
+func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) *EntityString {
+	ctx = ctx.WithTag(node.Data)
+	switch node.Data {
+	case "br":
+		return NewEntityString("\n")
+	case "hr":
+		return NewEntityString("---")
+	case "blockquote":
+		return parser.blockquoteToString(node, ctx)
+	case "ol", "ul":
+		return parser.listToString(node, ctx)
+	case "h1", "h2", "h3", "h4", "h5", "h6":
+		return parser.headerToString(node, ctx)
+	case "b", "strong", "i", "em", "s", "strike", "del", "u", "ins", "tt", "code":
+		return parser.basicFormatToString(node, ctx)
+	case "span", "font":
+		return parser.spanToString(node, ctx)
+	case "a":
+		return parser.linkToString(node, ctx)
+	case "pre":
+		// A <code> element directly inside <pre> is unwrapped so that its
+		// content is only marked as monospace once. The language given in the
+		// class attribute of the <code> element is currently not used.
+		content := node.FirstChild
+		if content != nil && content.Type == html.ElementNode && content.Data == "code" {
+			content = content.FirstChild
+		}
+		return parser.nodeToString(content, ctx.WithWhitespace()).Format(signalfmt.StyleMonospace)
+	default:
+		// Covers "p" as well as every unknown tag.
+		return parser.nodeToTagAwareString(node.FirstChild, ctx)
+	}
+}
+
+// singleNodeToString converts one node (without its siblings) and tags the
+// result with the element name, or with "text", "html" or "unknown" for
+// nodes that are not elements.
+//
+// Unless the context asks for whitespace to be preserved, newlines are
+// removed from text nodes; the node itself is updated in the process.
+func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
+	if node.Type == html.TextNode {
+		if !ctx.PreserveWhitespace {
+			node.Data = strings.ReplaceAll(node.Data, "\n", "")
+		}
+		return TaggedString{EntityString: NewEntityString(node.Data), tag: "text"}
+	}
+	if node.Type == html.ElementNode {
+		return TaggedString{EntityString: parser.tagToString(node, ctx), tag: node.Data}
+	}
+	if node.Type == html.DocumentNode {
+		return TaggedString{EntityString: parser.nodeToTagAwareString(node.FirstChild, ctx), tag: "html"}
+	}
+	return TaggedString{EntityString: &EntityString{}, tag: "unknown"}
+}
+
+// nodeToTaggedStrings converts the given node and all siblings following it,
+// returning one tagged string per node. A nil node yields a nil slice.
+func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, ctx Context) (strs []TaggedString) {
+	for cur := node; cur != nil; cur = cur.NextSibling {
+		strs = append(strs, parser.singleNodeToString(cur, ctx))
+	}
+	return strs
+}
+
+// BlockTags lists the tags whose content is put on lines of its own.
+var BlockTags = []string{"p", "h1", "h2", "h3", "h4", "h5", "h6", "ol", "ul", "pre", "blockquote", "div", "hr", "table"}
+
+// isBlockTag reports whether tag is one of BlockTags.
+func (parser *HTMLParser) isBlockTag(tag string) bool {
+	return slices.Contains(BlockTags, tag)
+}
+
+// nodeToTagAwareString converts the given node and its following siblings
+// into one string. The output of block tags is surrounded with newlines, and
+// the combined result is trimmed of surrounding whitespace.
+// The result is nil if there is nothing to convert.
+func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, ctx Context) *EntityString {
+	var combined *EntityString
+	for _, tagged := range parser.nodeToTaggedStrings(node, ctx) {
+		piece := tagged.EntityString
+		if parser.isBlockTag(tagged.tag) {
+			piece = NewEntityString("\n").Append(piece).AppendString("\n")
+		}
+		// Append on a nil receiver simply returns its argument, which covers the first piece.
+		combined = combined.Append(piece)
+	}
+	return combined.TrimSpace()
+}
+
+// nodeToStrings converts the given node and all siblings following it,
+// returning one string per node without the tag information.
+func (parser *HTMLParser) nodeToStrings(node *html.Node, ctx Context) (strs []*EntityString) {
+	for cur := node; cur != nil; cur = cur.NextSibling {
+		tagged := parser.singleNodeToString(cur, ctx)
+		strs = append(strs, tagged.EntityString)
+	}
+	return strs
+}
+
+// nodeToString converts the given node and its following siblings and
+// concatenates the results as-is, without block handling or trimming.
+func (parser *HTMLParser) nodeToString(node *html.Node, ctx Context) *EntityString {
+	parts := parser.nodeToStrings(node, ctx)
+	return JoinEntityString("", parts...)
+}
+
+// Parse converts Matrix HTML into text using the settings in this parser.
+// The result is nil if the HTML doesn't produce any output.
+func (parser *HTMLParser) Parse(htmlData string, ctx Context) *EntityString {
+ //htmlData = strings.Replace(htmlData, "\t", " ", -1)
+ // The parse error is discarded; a nil node results in a nil return value.
+ node, _ := html.Parse(strings.NewReader(htmlData))
+ return parser.nodeToTagAwareString(node, ctx)
+}
diff --git a/msgconv/signalfmt/convert.go b/msgconv/signalfmt/convert.go
index 1abfb2c3..fc1162fe 100644
--- a/msgconv/signalfmt/convert.go
+++ b/msgconv/signalfmt/convert.go
@@ -96,7 +96,7 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) *
// Maybe use NewUTF16String and do index replacements for the plaintext body too,
// or just replace the plaintext body by parsing the generated HTML.
content.Body = strings.Replace(content.Body, "\uFFFC", userInfo.Name, 1)
- br.Value = Mention(userInfo)
+ br.Value = Mention{UserInfo: userInfo, UUID: rv.MentionUuid}
}
lrt.Add(br)
}
diff --git a/msgconv/signalfmt/tags.go b/msgconv/signalfmt/tags.go
index 07756a50..74132a1a 100644
--- a/msgconv/signalfmt/tags.go
+++ b/msgconv/signalfmt/tags.go
@@ -25,14 +25,24 @@ import (
type BodyRangeValue interface {
String() string
Format(message string) string
+ Proto() signalpb.BodyRangeAssociatedValue
}
-type Mention UserInfo
+// Mention is a body range value that refers to a user.
+// UUID is the value used for the mention in the Signal protobuf.
+type Mention struct {
+ UserInfo
+ UUID string
+}
func (m Mention) String() string {
return fmt.Sprintf("Mention{MXID: id.UserID(%q), Name: %q}", m.MXID, m.Name)
}
+// Proto returns the mention as the associated value of a protobuf body range.
+func (m Mention) Proto() signalpb.BodyRangeAssociatedValue {
+ return &signalpb.BodyRange_MentionUuid{
+ MentionUuid: m.UUID,
+ }
+}
+
type Style int
const (
@@ -44,8 +54,10 @@ const (
StyleMonospace
)
-func (s Style) Proto() signalpb.BodyRange_Style {
- return signalpb.BodyRange_Style(s)
+// Proto returns the style as the associated value of a protobuf body range.
+// The numeric value of the style is converted directly into the protobuf enum.
+func (s Style) Proto() signalpb.BodyRangeAssociatedValue {
+ return &signalpb.BodyRange_Style_{
+ Style: signalpb.BodyRange_Style(s),
+ }
+}
func (s Style) String() string {
diff --git a/msgconv/signalfmt/tree.go b/msgconv/signalfmt/tree.go
index e4ec0291..5f37b1c0 100644
--- a/msgconv/signalfmt/tree.go
+++ b/msgconv/signalfmt/tree.go
@@ -16,12 +16,41 @@
package signalfmt
+import (
+ "fmt"
+ "sort"
+
+ "google.golang.org/protobuf/proto"
+
+ signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
+)
+
type BodyRange struct {
Start int
Length int
Value BodyRangeValue
}
+// BodyRangeList is a list of body ranges that can be sorted with the sort package.
+type BodyRangeList []BodyRange
+
+var _ sort.Interface = BodyRangeList(nil)
+
+// Len returns the number of ranges in the list.
+func (b BodyRangeList) Len() int {
+	return len(b)
+}
+
+// Less orders ranges by ascending start index. Ranges starting at the same
+// index are ordered by descending length, so an enclosing range sorts before
+// the ranges nested inside it.
+func (b BodyRangeList) Less(i, j int) bool {
+	if b[i].Start != b[j].Start {
+		return b[i].Start < b[j].Start
+	}
+	// The length is only a tie-breaker. Comparing it regardless of the start
+	// would allow both Less(i, j) and Less(j, i) to be true at the same time.
+	return b[i].Length > b[j].Length
+}
+
+// Swap exchanges the ranges at the two given indexes.
+func (b BodyRangeList) Swap(i, j int) {
+	b[i], b[j] = b[j], b[i]
+}
+
+// String formats the range as "start:length:value".
+func (b BodyRange) String() string {
+ return fmt.Sprintf("%d:%d:%v", b.Start, b.Length, b.Value)
+}
+
// End returns the end index of the range.
func (b BodyRange) End() int {
return b.Start + b.Length
@@ -35,8 +64,10 @@ func (b BodyRange) Offset(offset int) *BodyRange {
// TruncateStart changes the length of the range, so it starts at the given index and ends at the same index as before.
func (b BodyRange) TruncateStart(startAt int) *BodyRange {
- b.Length -= startAt - b.Start
- b.Start = startAt
+ if b.Start < startAt {
+ b.Length -= startAt - b.Start
+ b.Start = startAt
+ }
return &b
}
@@ -48,6 +79,14 @@ func (b BodyRange) TruncateEnd(maxEnd int) *BodyRange {
return &b
}
+// Proto converts the range into its protobuf representation.
+// NOTE(review): Start and Length are converted to uint32 without a range check,
+// so a negative value would wrap around; callers presumably only pass valid ranges.
+func (b BodyRange) Proto() *signalpb.BodyRange {
+ return &signalpb.BodyRange{
+ Start: proto.Uint32(uint32(b.Start)),
+ Length: proto.Uint32(uint32(b.Length)),
+ AssociatedValue: b.Value.Proto(),
+ }
+}
+
// LinkedRangeTree is a linked tree of formatting entities.
//
// It's meant to parse a list of Signal body ranges into nodes that either overlap completely or not at all,
diff --git a/pkg/signalmeow/protobuf/extra.go b/pkg/signalmeow/protobuf/extra.go
new file mode 100644
index 00000000..4475440a
--- /dev/null
+++ b/pkg/signalmeow/protobuf/extra.go
@@ -0,0 +1,3 @@
+package signalpb
+
+// BodyRangeAssociatedValue is an exported alias for the unexported interface type
+// isBodyRange_AssociatedValue, so that other packages can refer to it by name.
+type BodyRangeAssociatedValue = isBodyRange_AssociatedValue
diff --git a/pkg/signalmeow/sending.go b/pkg/signalmeow/sending.go
index c24b8427..6bc62543 100644
--- a/pkg/signalmeow/sending.go
+++ b/pkg/signalmeow/sending.go
@@ -411,16 +411,17 @@ func ReadReceptMessageForTimestamps(timestamps []uint64) *SignalContent {
}
}
-func DataMessageForText(text string) *SignalContent {
+// DataMessageForText builds a data message with the given text body and body ranges,
+// timestamped with the current message timestamp, and wraps it in a SignalContent.
+func DataMessageForText(text string, ranges []*signalpb.BodyRange) *SignalContent {
timestamp := currentMessageTimestamp()
dm := &signalpb.DataMessage{
- Body: proto.String(text),
- Timestamp: &timestamp,
+ Body: proto.String(text),
+ BodyRanges: ranges,
+ Timestamp: &timestamp,
}
return wrapDataMessageInContent(dm)
}
-func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption string) *SignalContent {
+func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption string, ranges []*signalpb.BodyRange) *SignalContent {
ap := (*signalpb.AttachmentPointer)(attachmentPointer) // Cast back to signalpb, this is okay AttachmentPointer is an alias
timestamp := currentMessageTimestamp()
dm := &signalpb.DataMessage{
diff --git a/portal.go b/portal.go
index d505fb8a..3bbb792d 100644
--- a/portal.go
+++ b/portal.go
@@ -47,12 +47,13 @@ import (
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/crypto/attachment"
"maunium.net/go/mautrix/event"
- "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-signal/database"
+ "go.mau.fi/mautrix-signal/msgconv/matrixfmt"
"go.mau.fi/mautrix-signal/msgconv/signalfmt"
"go.mau.fi/mautrix-signal/pkg/signalmeow"
+ signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf"
)
type portalSignalMessage struct {
@@ -649,34 +650,27 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
switch content.MsgType {
case event.MsgText, event.MsgEmote, event.MsgNotice:
- var text string
- var mentions []string
- if content.Format == event.FormatHTML {
- text, mentions = portal.parseMentionsFromMatrixBody(content)
- } else {
- text = content.Body
- mentions = nil
- }
if content.MsgType == event.MsgNotice && !portal.bridge.Config.Bridge.BridgeNotices {
return nil, errMNoticeDisabled
}
if content.MsgType == event.MsgEmote {
- text = "/me " + text
+ content.Body = "/me " + content.Body
+ if content.FormattedBody != "" {
+ content.FormattedBody = "/me " + content.FormattedBody
+ }
}
+ outgoingMessage = signalmeow.DataMessageForText(matrixfmt.Parse(matrixFormatParams, content))
if ctx.Err() != nil {
return nil, ctx.Err()
}
- outgoingMessage = signalmeow.DataMessageForText(text)
- if mentions != nil && len(mentions) > 0 {
- signalmeow.AddMentionsToDataMessage(outgoingMessage, mentions)
- }
case event.MsgImage:
fileName := content.Body
var caption string
+ var ranges []*signalpb.BodyRange
if content.FileName != "" && content.Body != content.FileName {
fileName = content.FileName
- caption = content.Body
+ caption, ranges = matrixfmt.Parse(matrixFormatParams, content)
}
image, err := portal.downloadAndDecryptMatrixMedia(ctx, content)
if err != nil {
@@ -690,15 +684,9 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if err != nil {
return nil, err
}
- outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption)
+ outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges)
case event.MessageType(event.EventSticker.Type):
- fileName := content.Body
- var caption string
- if content.FileName != "" && content.Body != content.FileName {
- fileName = content.FileName
- caption = content.Body
- }
image, err := portal.downloadAndDecryptMatrixMedia(ctx, content)
if err != nil {
return nil, err
@@ -707,17 +695,18 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if err != nil {
return nil, err
}
- attachmentPointer, err := signalmeow.UploadAttachment(sender.SignalDevice, convertedSticker, newMimeType, fileName)
+ attachmentPointer, err := signalmeow.UploadAttachment(sender.SignalDevice, convertedSticker, newMimeType, content.FileName)
if err != nil {
return nil, err
}
- outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption)
+ outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, "", nil)
case event.MsgVideo:
fileName := content.Body
var caption string
+ var ranges []*signalpb.BodyRange
if content.FileName != "" && content.Body != content.FileName {
fileName = content.FileName
- caption = content.Body
+ caption, ranges = matrixfmt.Parse(matrixFormatParams, content)
}
image, err := portal.downloadAndDecryptMatrixMedia(ctx, content)
if err != nil {
@@ -731,14 +720,15 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if err != nil {
return nil, err
}
- outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption)
+ outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges)
case event.MsgAudio:
fileName := content.Body
var caption string
+ var ranges []*signalpb.BodyRange
if content.FileName != "" && content.Body != content.FileName {
fileName = content.FileName
- caption = content.Body
+ caption, ranges = matrixfmt.Parse(matrixFormatParams, content)
}
image, err := portal.downloadAndDecryptMatrixMedia(ctx, content)
if err != nil {
@@ -752,14 +742,15 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if err != nil {
return nil, err
}
- outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption)
+ outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges)
case event.MsgFile:
fileName := content.Body
var caption string
+ var ranges []*signalpb.BodyRange
if content.FileName != "" && content.Body != content.FileName {
fileName = content.FileName
- caption = content.Body
+ caption, ranges = matrixfmt.Parse(matrixFormatParams, content)
}
file, err := portal.downloadAndDecryptMatrixMedia(ctx, content)
if err != nil {
@@ -769,7 +760,7 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev
if err != nil {
return nil, err
}
- outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption)
+ outgoingMessage = signalmeow.DataMessageForAttachment(attachmentPointer, caption, ranges)
case event.MsgLocation:
fallthrough
@@ -1012,79 +1003,13 @@ func (portal *Portal) addDisappearingMessage(eventID id.EventID, expireInSeconds
portal.bridge.disappearingMessagesManager.AddDisappearingMessage(eventID, portal.MXID, expireInSeconds, startTimerNow)
}
-const mentionedSignalIDsContextKey = "fi.mau.signal.mentioned_ids"
-
-func (portal *Portal) parseMentionsFromMatrixBody(content *event.MessageEventContent) (string, []string) {
- var allowedMentions map[string]bool = nil
- var mentionedSignalIDs []string
-
- // If the matrix event has explicit mentions, we only want to allow parsing mentions from those
- if content.Mentions != nil {
- allowedMentions = make(map[string]bool, len(content.Mentions.UserIDs))
- mentionedSignalIDs = make([]string, 0, len(content.Mentions.UserIDs))
- for _, userID := range content.Mentions.UserIDs {
- var signalID string
- if puppet := portal.bridge.GetPuppetByMXID(userID); puppet != nil {
- signalID = puppet.SignalID
- mentionedSignalIDs = append(mentionedSignalIDs, puppet.SignalID)
- }
- if signalID != "" && !allowedMentions[signalID] {
- allowedMentions[signalID] = true
- mentionedSignalIDs = append(mentionedSignalIDs, signalID)
- }
- }
- }
-
- // Parse what mentions we can find out of the HTML, and replace with unicode replacement character
- matrixHTMLParser := &format.HTMLParser{
- TabsToSpaces: 4,
- Newline: "\n",
-
- PillConverter: func(displayname, mxid, eventID string, ctx format.Context) string {
- if mxid[0] == '@' {
- var signalID string
- if puppet := portal.bridge.GetPuppetByMXID(id.UserID(mxid)); puppet != nil {
- signalID = puppet.SignalID
- }
- if signalID != "" && (allowedMentions == nil || allowedMentions[signalID]) {
- if allowedMentions == nil {
- ids, ok := ctx.ReturnData[mentionedSignalIDsContextKey].([]string)
- if !ok {
- ctx.ReturnData[mentionedSignalIDsContextKey] = []string{signalID}
- } else {
- ctx.ReturnData[mentionedSignalIDsContextKey] = append(ids, signalID)
- }
- }
- // Signal needs the Unicode replacement character, then it will add the name itself
- return "\uFFFC"
- }
- }
- return displayname
- },
- BoldConverter: func(text string, _ format.Context) string { return fmt.Sprintf("*%s*", text) },
- ItalicConverter: func(text string, _ format.Context) string { return fmt.Sprintf("_%s_", text) },
- StrikethroughConverter: func(text string, _ format.Context) string { return fmt.Sprintf("~%s~", text) },
- MonospaceConverter: func(text string, _ format.Context) string { return fmt.Sprintf("```%s```", text) },
- MonospaceBlockConverter: func(text, language string, _ format.Context) string { return fmt.Sprintf("```%s```", text) },
- }
-
- formatContext := format.NewContext()
- parsedBody := matrixHTMLParser.Parse(content.FormattedBody, formatContext)
-
- // If we didn't have any explicit mentions, we can use the ones we parsed from the HTML
- if content.Mentions == nil {
- mentionedSignalIDs, _ = formatContext.ReturnData[mentionedSignalIDsContextKey].([]string)
- }
-
- return parsedBody, mentionedSignalIDs
-}
-
-var formatParams *signalfmt.FormatParams
+var signalFormatParams *signalfmt.FormatParams
+var matrixFormatParams *matrixfmt.HTMLParser
func (portal *Portal) handleSignalTextMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error {
timestamp := portalMessage.message.Base().Timestamp
msg := (portalMessage.message).(signalmeow.IncomingSignalMessageText)
- content := signalfmt.Parse(msg.Content, msg.ContentRanges, formatParams)
+ content := signalfmt.Parse(msg.Content, msg.ContentRanges, signalFormatParams)
portal.addSignalQuote(content, msg.Quote)
resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0)
if err != nil {
@@ -1354,7 +1279,7 @@ func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.Eve
func (portal *Portal) handleSignalAttachmentMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error {
timestamp := portalMessage.message.Base().Timestamp
msg := (portalMessage.message).(signalmeow.IncomingSignalMessageAttachment)
- content := signalfmt.Parse(msg.Caption, msg.CaptionRanges, formatParams)
+ content := signalfmt.Parse(msg.Caption, msg.CaptionRanges, signalFormatParams)
content.Info = &event.FileInfo{
MimeType: msg.ContentType,
Size: int(msg.Size),