From 0348502f027410dfc48b0c72242882a2b01fe74e Mon Sep 17 00:00:00 2001 From: jeffyanta Date: Fri, 23 Aug 2024 09:58:06 -0400 Subject: [PATCH 01/71] Fix Twitter user info update worker when username is changed (#173) --- pkg/code/async/user/twitter.go | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index c8612e9f..591fa1e3 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -171,6 +171,14 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { func (p *service) refreshTwitterUserInfo(ctx context.Context, username string) error { user, err := p.twitterClient.GetUserByUsername(ctx, username) if err != nil { + if strings.Contains(strings.ToLower(err.Error()), "could not find user with username") { + err = p.onTwitterUsernameNotFound(ctx, username) + if err != nil { + return errors.Wrap(err, "error updating cached user state") + } + return nil + } + return errors.Wrap(err, "error getting user info from twitter") } @@ -220,6 +228,25 @@ func (p *service) updateCachedTwitterUser(ctx context.Context, user *twitter_lib } } +func (p *service) onTwitterUsernameNotFound(ctx context.Context, username string) error { + record, err := p.data.GetTwitterUserByUsername(ctx, username) + switch err { + case nil: + case twitter.ErrUserNotFound: + return nil + default: + return errors.Wrap(err, "error getting cached twitter user") + } + + record.LastUpdatedAt = time.Now() + + err = p.data.SaveTwitterUser(ctx, record) + if err != nil { + return errors.Wrap(err, "error updating cached twitter user") + } + return nil +} + func (p *service) findNewRegistrationTweets(ctx context.Context) ([]*twitter_lib.Tweet, error) { var pageToken *string var res []*twitter_lib.Tweet From 0714a15c1093db595719ebe0dc77714e389cace0 Mon Sep 17 00:00:00 2001 From: jeffyanta Date: Mon, 9 Sep 2024 15:58:06 -0400 Subject: [PATCH 02/71] Add auto-reply to Twitter registration tweets (#174) --- go.mod | 1 + go.sum | 2 + pkg/code/async/user/twitter.go | 74 ++++++++++++++++------- pkg/twitter/client.go | 103 +++++++++++++++++++++++++++++---- 4 files changed, 150 insertions(+), 30 deletions(-) diff --git a/go.mod b/go.mod index 8701d46c..d8dcccb6 100644 --- a/go.mod +++ b/go.mod @@ -65,6 +65,7 @@ require ( github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dghubble/oauth1 v0.7.3 // indirect github.com/docker/cli v20.10.7+incompatible // indirect github.com/docker/docker v20.10.7+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect diff --git a/go.sum b/go.sum index ca244823..2dfb305f 100644 --- a/go.sum +++ b/go.sum @@ -141,6 +141,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dghubble/oauth1 v0.7.3 h1:EkEM/zMDMp3zOsX2DC/ZQ2vnEX3ELK0/l9kb+vs4ptE= +github.com/dghubble/oauth1 v0.7.3/go.mod h1:oxTe+az9NSMIucDPDCCtzJGsPhciJV33xocHfcR2sVY= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-sip13 v0.0.0-20181026042036-e10d5fee7954/go.mod h1:vAd38F8PWV+bWy6jNmig1y/TA+kYO4g3RSRF0IAv0no= github.com/docker/cli v20.10.7+incompatible h1:pv/3NqibQKphWZiAskMzdz8w0PRbtTaEB+f6NwdU7Is= diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index 591fa1e3..31e397d0 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -4,6 +4,7 @@ import ( "context" "crypto/ed25519" "database/sql" + "fmt" "strings" "time" @@ -11,6 +12,7 @@ import ( "github.com/mr-tron/base58" "github.com/newrelic/go-agent/v3/newrelic" "github.com/pkg/errors" + "github.com/sirupsen/logrus" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" userpb "github.com/code-payments/code-protobuf-api/generated/go/user/v1" @@ -103,6 +105,8 @@ func (p *service) twitterUserInfoUpdateWorker(serviceCtx context.Context, interv } func (p *service) processNewTwitterRegistrations(ctx context.Context) error { + log := p.log.WithField("method", "processNewTwitterRegistrations") + tweets, err := p.findNewRegistrationTweets(ctx) if err != nil { return errors.Wrap(err, "error finding new registration tweets") @@ -113,6 +117,11 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { return errors.Errorf("author missing in tweet %s", tweet.ID) } + log := log.WithFields(logrus.Fields{ + "tweet": tweet.ID, + "username": tweet.AdditionalMetadata.Author, + }) + // Attempt to find a verified tip account from the registration tweet tipAccount, registrationNonce, err := p.findVerifiedTipAccountRegisteredInTweet(ctx, tweet) switch err { @@ -140,7 +149,21 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { switch err { case nil: - go push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + // todo: all of these success handlers are fire and forget best-effort delivery + + go func() { + err := push_util.SendTwitterAccountConnectedPushNotification(ctx, p.data, p.pusher, tipAccount) + if err != nil { + log.WithError(err).Warn("failure sending success push") + } + }() + + go func() { + err := p.sendRegistrationSuccessReply(ctx, tweet.ID, tweet.AdditionalMetadata.Author.Username) + if err != nil { + log.WithError(err).Warn("failure sending success reply") + } + }() case twitter.ErrDuplicateTipAddress: err = p.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = p.data.MarkTwitterNonceAsUsed(ctx, tweet.ID, *registrationNonce) @@ -228,25 +251,6 @@ func (p *service) updateCachedTwitterUser(ctx context.Context, user *twitter_lib } } -func (p *service) onTwitterUsernameNotFound(ctx context.Context, username string) error { - record, err := p.data.GetTwitterUserByUsername(ctx, username) - switch err { - case nil: - case twitter.ErrUserNotFound: - return nil - default: - return errors.Wrap(err, "error getting cached twitter user") - } - - record.LastUpdatedAt = time.Now() - - err = p.data.SaveTwitterUser(ctx, record) - if err != nil { - return errors.Wrap(err, "error updating cached twitter user") - } - return nil -} - func (p *service) findNewRegistrationTweets(ctx context.Context) ([]*twitter_lib.Tweet, error) { var pageToken *string var res []*twitter_lib.Tweet @@ -361,6 +365,36 @@ func (p *service) findVerifiedTipAccountRegisteredInTweet(ctx context.Context, t return nil, nil, errTwitterRegistrationNotFound } +func (p *service) sendRegistrationSuccessReply(ctx context.Context, regristrationTweetId, username string) error { + // todo: localize this + message := fmt.Sprintf( + "@%s your X account is now connected! Share this link to receive tips: https://tipcard.getcode.com/x/%s", + username, + username, + ) + _, err := p.twitterClient.SendReply(ctx, regristrationTweetId, message) + return err +} + +func (p *service) onTwitterUsernameNotFound(ctx context.Context, username string) error { + record, err := p.data.GetTwitterUserByUsername(ctx, username) + switch err { + case nil: + case twitter.ErrUserNotFound: + return nil + default: + return errors.Wrap(err, "error getting cached twitter user") + } + + record.LastUpdatedAt = time.Now() + + err = p.data.SaveTwitterUser(ctx, record) + if err != nil { + return errors.Wrap(err, "error updating cached twitter user") + } + return nil +} + func toProtoVerifiedType(value string) userpb.TwitterUser_VerifiedType { switch value { case "blue": diff --git a/pkg/twitter/client.go b/pkg/twitter/client.go index 9cf9bb08..6949d2ec 100644 --- a/pkg/twitter/client.go +++ b/pkg/twitter/client.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/dghubble/oauth1" "github.com/pkg/errors" "github.com/code-payments/code-server/pkg/metrics" @@ -28,8 +29,10 @@ const ( type Client struct { httpClient *http.Client - clientId string - clientSecret string + clientId string + clientSecret string + accessToken string + accessTokenSecret string bearerTokenMu sync.RWMutex bearerToken string @@ -37,11 +40,13 @@ type Client struct { } // NewClient returns a new Twitter client -func NewClient(clientId, clientSecret string) *Client { +func NewClient(clientId, clientSecret, accessToken, accessTokenSecret string) *Client { return &Client{ - httpClient: http.DefaultClient, - clientId: clientId, - clientSecret: clientSecret, + httpClient: http.DefaultClient, + clientId: clientId, + clientSecret: clientSecret, + accessToken: accessToken, + accessTokenSecret: accessTokenSecret, } } @@ -143,8 +148,16 @@ func (c *Client) SearchRecentTweets(ctx context.Context, searchString string, ma return tweets, nextToken, err } +// SendReply sends a reply to the provided tweet +func (c *Client) SendReply(ctx context.Context, tweetId, text string) (string, error) { + tracer := metrics.TraceMethodCall(ctx, metricsStructName, "SendReply") + defer tracer.End() + + return c.sendTweet(ctx, text, &tweetId) +} + func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) { - bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret) + bearerToken, err := c.getBearerToken() if err != nil { return nil, err } @@ -189,7 +202,7 @@ func (c *Client) getUser(ctx context.Context, fromUrl string) (*User, error) { } func (c *Client) getTweets(ctx context.Context, fromUrl string) ([]*Tweet, *string, error) { - bearerToken, err := c.getBearerToken(c.clientId, c.clientSecret) + bearerToken, err := c.getBearerToken() if err != nil { return nil, nil, err } @@ -253,7 +266,77 @@ func (c *Client) getTweets(ctx context.Context, fromUrl string) ([]*Tweet, *stri return result.Data, result.Meta.NextToken, nil } -func (c *Client) getBearerToken(clientId, clientSecret string) (string, error) { +func (c *Client) sendTweet(ctx context.Context, text string, inReplyTo *string) (string, error) { + apiUrl := baseUrl + "tweets" + + type ReplyParams struct { + InReplyToTweetId string `json:"in_reply_to_tweet_id"` + } + type Request struct { + Text string `json:"text"` + Reply *ReplyParams `json:"reply"` + } + + reqPayload := Request{ + Text: text, + } + if inReplyTo != nil { + reqPayload.Reply = &ReplyParams{ + InReplyToTweetId: *inReplyTo, + } + } + + reqJson, err := json.Marshal(reqPayload) + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", apiUrl, bytes.NewBuffer(reqJson)) + if err != nil { + return "", err + } + + req = req.WithContext(ctx) + + req.Header.Set("Content-Type", "application/json") + + config := oauth1.NewConfig(c.clientId, c.clientSecret) + token := oauth1.NewToken(c.accessToken, c.accessTokenSecret) + httpClient := config.Client(oauth1.NoContext, token) + + resp, err := httpClient.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return "", fmt.Errorf("unexpected http status code: %d", resp.StatusCode) + } + + var result struct { + Data struct { + Id *string `json:"id"` + } `json:"data"` + Errors []*twitterError `json:"errors"` + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + + if err := json.Unmarshal(body, &result); err != nil { + return "", err + } + + if len(result.Errors) > 0 { + return "", result.Errors[0].toError() + } + return *result.Data.Id, nil +} + +func (c *Client) getBearerToken() (string, error) { c.bearerTokenMu.RLock() if time.Since(c.lastBearerTokenRefresh) < bearerTokenMaxAge { c.bearerTokenMu.RUnlock() @@ -275,7 +358,7 @@ func (c *Client) getBearerToken(clientId, clientSecret string) (string, error) { } req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.SetBasicAuth(clientId, clientSecret) + req.SetBasicAuth(c.clientId, c.clientSecret) resp, err := c.httpClient.Do(req) if err != nil { From 8401d3e83b0b1f787519ea5ac25641fcb3422c69 Mon Sep 17 00:00:00 2001 From: jeffyanta Date: Mon, 9 Sep 2024 16:59:15 -0400 Subject: [PATCH 03/71] Handle the case where a Twitter user is suspended in the info update worker (#175) --- pkg/code/async/user/twitter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index 31e397d0..b591874a 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -194,7 +194,7 @@ func (p *service) processNewTwitterRegistrations(ctx context.Context) error { func (p *service) refreshTwitterUserInfo(ctx context.Context, username string) error { user, err := p.twitterClient.GetUserByUsername(ctx, username) if err != nil { - if strings.Contains(strings.ToLower(err.Error()), "could not find user with username") { + if strings.Contains(strings.ToLower(err.Error()), "could not find user with username") || strings.Contains(strings.ToLower(err.Error()), "user has been suspended") { err = p.onTwitterUsernameNotFound(ctx, username) if err != nil { return errors.Wrap(err, "error updating cached user state") From 80b6dd1ff732a7a2c41ef3c901bebcc4cc350394 Mon Sep 17 00:00:00 2001 From: jeffyanta Date: Tue, 10 Sep 2024 11:01:22 -0400 Subject: [PATCH 04/71] Update Twitter connection reply text (#176) --- pkg/code/async/user/twitter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/async/user/twitter.go b/pkg/code/async/user/twitter.go index b591874a..eb1b6f5d 100644 --- a/pkg/code/async/user/twitter.go +++ b/pkg/code/async/user/twitter.go @@ -368,7 +368,7 @@ func (p *service) findVerifiedTipAccountRegisteredInTweet(ctx context.Context, t func (p *service) sendRegistrationSuccessReply(ctx context.Context, regristrationTweetId, username string) error { // todo: localize this message := fmt.Sprintf( - "@%s your X account is now connected! Share this link to receive tips: https://tipcard.getcode.com/x/%s", + "@%s your X account is now connected. Share this link to receive tips: https://tipcard.getcode.com/x/%s", username, username, ) From e7031b5e1cd202a977fde18c268022279a9ae91b Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Tue, 17 Sep 2024 09:57:26 -0400 Subject: [PATCH 05/71] update intent DB models to support pay-for-chat intents (#186) * update intent DB models to support pay-for-chat intents * Update to latest code-protobuf-api, and apply related fixes --- go.mod | 4 +- go.sum | 4 +- pkg/code/chat/message_code_team.go | 4 +- pkg/code/chat/message_kin_purchases.go | 12 ++--- pkg/code/chat/sender_test.go | 4 +- pkg/code/data/intent/intent.go | 16 +++++++ pkg/code/data/intent/postgres/model.go | 38 +++++++++++---- pkg/code/data/intent/postgres/store_test.go | 4 +- pkg/code/data/intent/tests/tests.go | 52 +++++++++++++++++++++ pkg/code/push/notifications.go | 12 ++--- pkg/code/server/grpc/chat/server.go | 10 ++-- pkg/code/server/grpc/chat/server_test.go | 22 ++++----- pkg/code/server/grpc/messaging/server.go | 3 +- pkg/code/server/grpc/messaging/testutil.go | 4 +- pkg/code/server/grpc/transaction/v2/swap.go | 4 +- 15 files changed, 141 insertions(+), 52 deletions(-) diff --git a/go.mod b/go.mod index d8dcccb6..4e808287 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,8 @@ require ( firebase.google.com/go/v4 v4.8.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 - github.com/code-payments/code-protobuf-api v1.16.6 + github.com/code-payments/code-protobuf-api v1.19.0 + github.com/dghubble/oauth1 v0.7.3 github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/golang-jwt/jwt/v5 v5.0.0 @@ -65,7 +66,6 @@ require ( github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dghubble/oauth1 v0.7.3 // indirect github.com/docker/cli v20.10.7+incompatible // indirect github.com/docker/docker v20.10.7+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect diff --git a/go.sum b/go.sum index 2dfb305f..7414b428 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,8 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.16.6 h1:QCot0U+4Ar5SdSX4v955FORMsd3Qcf0ZgkoqlGJZzu0= -github.com/code-payments/code-protobuf-api v1.16.6/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/code-payments/code-protobuf-api v1.19.0 h1:md/eJhqltz8dDY0U8hwT/42C3h+kP+W/68D7RMSjqPo= +github.com/code-payments/code-protobuf-api v1.19.0/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index fe8a7049..536370a3 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -48,8 +48,8 @@ func newIncentiveMessage(localizedTextKey string, intentRecord *intent.Record) ( content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedTextKey, }, }, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index 1ec10b68..f9a2c6fd 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -40,8 +40,8 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei func ToUsdcDepositedMessage(signature string, ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcDeposited, }, }, @@ -60,8 +60,8 @@ func NewUsdcBeingConvertedMessage(ts time.Time) (*chatpb.ChatMessage, error) { content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageUsdcBeingConverted, }, }, @@ -79,8 +79,8 @@ func ToKinAvailableForUseMessage(signature string, ts time.Time, purchases ...*t content := []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.ChatMessageKinAvailableForUse, }, }, diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index 7625a1f3..3438b3c8 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -159,8 +159,8 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { var content []*chatpb.Content for i := 0; i < contentLength; i++ { content = append(content, &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("key%d", rand.Uint32()), }, }, diff --git a/pkg/code/data/intent/intent.go b/pkg/code/data/intent/intent.go index 70a840cc..6c7f5206 100644 --- a/pkg/code/data/intent/intent.go +++ b/pkg/code/data/intent/intent.go @@ -109,9 +109,13 @@ type SendPrivatePaymentMetadata struct { IsRemoteSend bool IsMicroPayment bool IsTip bool + IsChat bool // Set when IsTip = true TipMetadata *TipMetadata + + // Set when IsChat = true + ChatId string } type TipMetadata struct { @@ -578,8 +582,10 @@ func (m *SendPrivatePaymentMetadata) Clone() SendPrivatePaymentMetadata { IsRemoteSend: m.IsRemoteSend, IsMicroPayment: m.IsMicroPayment, IsTip: m.IsTip, + IsChat: m.IsChat, TipMetadata: tipMetadata, + ChatId: m.ChatId, } } @@ -605,8 +611,10 @@ func (m *SendPrivatePaymentMetadata) CopyTo(dst *SendPrivatePaymentMetadata) { dst.IsRemoteSend = m.IsRemoteSend dst.IsMicroPayment = m.IsMicroPayment dst.IsTip = m.IsTip + dst.IsChat = m.IsChat dst.TipMetadata = tipMetadata + dst.ChatId = m.ChatId } func (m *SendPrivatePaymentMetadata) Validate() error { @@ -650,6 +658,14 @@ func (m *SendPrivatePaymentMetadata) Validate() error { return errors.New("tip metadata can only be set for tips") } + if m.IsChat { + if len(m.ChatId) == 0 { + return errors.New("chat_id required for chat") + } + } else if m.ChatId != "" { + return errors.New("chat_id can only be set for chats") + } + return nil } diff --git a/pkg/code/data/intent/postgres/model.go b/pkg/code/data/intent/postgres/model.go index 77b8e7a0..4aa62649 100644 --- a/pkg/code/data/intent/postgres/model.go +++ b/pkg/code/data/intent/postgres/model.go @@ -49,6 +49,8 @@ type intentModel struct { IsTip bool `db:"is_tip"` TipPlatform sql.NullInt16 `db:"tip_platform"` TippedUsername sql.NullString `db:"tipped_username"` + IsChat bool `db:"is_chat"` + ChatId sql.NullString `db:"chat_id"` RelationshipTo sql.NullString `db:"relationship_to"` InitiatorPhoneNumber sql.NullString `db:"phone_number"` // todo: rename the DB field to initiator_phone_number State uint `db:"state"` @@ -106,6 +108,7 @@ func toIntentModel(obj *intent.Record) (*intentModel, error) { m.IsRemoteSend = obj.SendPrivatePaymentMetadata.IsRemoteSend m.IsMicroPayment = obj.SendPrivatePaymentMetadata.IsMicroPayment m.IsTip = obj.SendPrivatePaymentMetadata.IsTip + m.IsChat = obj.SendPrivatePaymentMetadata.IsChat if m.IsTip { m.TipPlatform = sql.NullInt16{ @@ -117,6 +120,13 @@ func toIntentModel(obj *intent.Record) (*intentModel, error) { String: obj.SendPrivatePaymentMetadata.TipMetadata.Username, } } + + if m.IsChat { + m.ChatId = sql.NullString{ + Valid: true, + String: obj.SendPrivatePaymentMetadata.ChatId, + } + } case intent.ReceivePaymentsPrivately: m.Source = obj.ReceivePaymentsPrivatelyMetadata.Source m.Quantity = obj.ReceivePaymentsPrivatelyMetadata.Quantity @@ -224,6 +234,7 @@ func fromIntentModel(obj *intentModel) *intent.Record { IsRemoteSend: obj.IsRemoteSend, IsMicroPayment: obj.IsMicroPayment, IsTip: obj.IsTip, + IsChat: obj.IsChat, } if record.SendPrivatePaymentMetadata.IsTip { @@ -232,6 +243,11 @@ func fromIntentModel(obj *intentModel) *intent.Record { Username: obj.TippedUsername.String, } } + + if record.SendPrivatePaymentMetadata.IsChat { + record.SendPrivatePaymentMetadata.ChatId = obj.ChatId.String + } + case intent.ReceivePaymentsPrivately: record.ReceivePaymentsPrivatelyMetadata = &intent.ReceivePaymentsPrivatelyMetadata{ Source: obj.Source, @@ -300,16 +316,16 @@ func fromIntentModel(obj *intentModel) *intent.Record { func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { query := `INSERT INTO ` + intentTableName + ` - (intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + (intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28) ON CONFLICT (intent_id) DO UPDATE - SET state = $25 + SET state = $27 WHERE ` + intentTableName + `.intent_id = $1 RETURNING - id, intent_id, intent_type, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at` + id, intent_id, intent_type, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at` err := tx.QueryRowxContext( ctx, @@ -334,9 +350,11 @@ func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { m.IsIssuerVoidingGiftCard, m.IsMicroPayment, m.IsTip, + m.IsChat, m.RelationshipTo, m.TipPlatform, m.TippedUsername, + m.ChatId, m.InitiatorPhoneNumber, m.State, m.CreatedAt, @@ -349,7 +367,7 @@ func (m *intentModel) dbSave(ctx context.Context, db *sqlx.DB) error { func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE intent_id = $1 LIMIT 1` @@ -364,7 +382,7 @@ func dbGetIntent(ctx context.Context, db *sqlx.DB, intentID string) (*intentMode func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType intent.Type, owner string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE owner = $1 AND intent_type = $2 ORDER BY created_at DESC @@ -381,7 +399,7 @@ func dbGetLatestByInitiatorAndType(ctx context.Context, db *sqlx.DB, intentType func dbGetAllByOwner(ctx context.Context, db *sqlx.DB, owner string, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE (owner = $1 OR destination_owner = $1) AND (intent_type != $2 AND intent_type != $3) ` @@ -542,7 +560,7 @@ func dbGetNetBalanceFromPrePrivacy2022Intents(ctx context.Context, db *sqlx.DB, func dbGetLatestSaveRecentRootIntentForTreasury(ctx context.Context, db *sqlx.DB, treasury string) (*intentModel, error) { res := &intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE treasury_pool = $1 and intent_type = $2 ORDER BY id DESC @@ -559,7 +577,7 @@ func dbGetLatestSaveRecentRootIntentForTreasury(ctx context.Context, db *sqlx.DB func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE destination = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 @@ -591,7 +609,7 @@ func dbGetOriginalGiftCardIssuedIntent(ctx context.Context, db *sqlx.DB, giftCar func dbGetGiftCardClaimedIntent(ctx context.Context, db *sqlx.DB, giftCardVault string) (*intentModel, error) { res := []*intentModel{} - query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, relationship_to, tip_platform, tipped_username, phone_number, state, created_at + query := `SELECT id, intent_id, intent_type, owner, source, destination_owner, destination, quantity, treasury_pool, recent_root, exchange_currency, exchange_rate, native_amount, usd_market_value, is_withdraw, is_deposit, is_remote_send, is_returned, is_issuer_voiding_gift_card, is_micro_payment, is_tip, is_chat, relationship_to, tip_platform, tipped_username, chat_id, phone_number, state, created_at FROM ` + intentTableName + ` WHERE source = $1 and intent_type = $2 AND state != $3 AND is_remote_send IS TRUE LIMIT 2 diff --git a/pkg/code/data/intent/postgres/store_test.go b/pkg/code/data/intent/postgres/store_test.go index fdaae93b..18a7d1c9 100644 --- a/pkg/code/data/intent/postgres/store_test.go +++ b/pkg/code/data/intent/postgres/store_test.go @@ -47,13 +47,15 @@ const ( is_issuer_voiding_gift_card BOOL NOT NULL, is_micro_payment BOOL NOT NULL, is_tip BOOL NOT NULL, + is_chat BOOL NOT NULL, relationship_to TEXT NULL, tip_platform INTEGER NULL, tipped_username TEXT NULL, - phone_number text NULL, + chat_id TEXT NULL, + phone_number TEXT NULL, state integer NOT NULL, diff --git a/pkg/code/data/intent/tests/tests.go b/pkg/code/data/intent/tests/tests.go index c17af4cb..2769b092 100644 --- a/pkg/code/data/intent/tests/tests.go +++ b/pkg/code/data/intent/tests/tests.go @@ -37,6 +37,7 @@ func RunTests(t *testing.T, s intent.Store, teardown func()) { testGetLatestSaveRecentRootIntentForTreasury, testGetOriginalGiftCardIssuedIntent, testGetGiftCardClaimedIntent, + testChatPayment, } { tf(t, s) teardown() @@ -1011,3 +1012,54 @@ func testGetGiftCardClaimedIntent(t *testing.T, s intent.Store) { assert.Equal(t, "i9", actual.IntentId) }) } + +func testChatPayment(t *testing.T, s intent.Store) { + t.Run("testChatPayment", func(t *testing.T) { + record := &intent.Record{ + IntentId: "i1", + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: "init1", + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationOwnerAccount: "do", + DestinationTokenAccount: "dt", + Quantity: 1, + ExchangeCurrency: "USD", + ExchangeRate: 1, + NativeAmount: 1, + UsdMarketValue: 1, + IsChat: true, + ChatId: "chatId", + }, + } + require.NoError(t, s.Save(context.Background(), record)) + + saved, err := s.Get(context.Background(), record.IntentId) + require.NoError(t, err) + require.Equal(t, record, saved) + }) + + t.Run("testChatPayment invalid", func(t *testing.T) { + base := &intent.Record{ + IntentId: "i1", + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: "init1", + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationOwnerAccount: "do", + DestinationTokenAccount: "dt", + Quantity: 1, + ExchangeCurrency: "USD", + ExchangeRate: 1, + NativeAmount: 1, + UsdMarketValue: 1, + }, + } + + r := base.Clone() + r.SendPrivatePaymentMetadata.IsChat = true + require.Error(t, s.Save(context.Background(), &r)) + + r = base.Clone() + r.SendPrivatePaymentMetadata.ChatId = "chatId" + require.Error(t, s.Save(context.Background(), &r)) + }) +} diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index ccc361fb..42aff9be 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -320,15 +320,15 @@ func SendChatMessagePushNotification( for _, content := range chatMessage.Content { var contentToPush *chatpb.Content switch typedContent := content.Type.(type) { - case *chatpb.Content_Localized: - localizedPushBody, err := localization.Localize(locale, typedContent.Localized.KeyOrText) + case *chatpb.Content_ServerLocalized: + localizedPushBody, err := localization.Localize(locale, typedContent.ServerLocalized.KeyOrText) if err != nil { continue } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedPushBody, }, }, @@ -358,8 +358,8 @@ func SendChatMessagePushNotification( } contentToPush = &chatpb.Content{ - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: localizedPushBody, }, }, diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index 362c0fd7..17c201ca 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -147,7 +147,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } protoMetadata.Title = &chatpb.ChatMetadata_Localized{ - Localized: &chatpb.LocalizedContent{ + Localized: &chatpb.ServerLocalizedContent{ KeyOrText: localization.LocalizeWithFallback( locale, localization.GetLocalizationKeyForUserAgent(ctx, chatProperties.TitleLocalizationKey), @@ -298,11 +298,11 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest for _, content := range protoChatMessage.Content { switch typed := content.Type.(type) { - case *chatpb.Content_Localized: - typed.Localized.KeyOrText = localization.LocalizeWithFallback( + case *chatpb.Content_ServerLocalized: + typed.ServerLocalized.KeyOrText = localization.LocalizeWithFallback( locale, - localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), - typed.Localized.KeyOrText, + localization.GetLocalizationKeyForUserAgent(ctx, typed.ServerLocalized.KeyOrText), + typed.ServerLocalized.KeyOrText, ) } } diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/server_test.go index f6dd6eff..627764b0 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/server_test.go @@ -102,8 +102,8 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -242,7 +242,7 @@ func TestGetChatsAndMessages_HappyPath(t *testing.T) { require.Len(t, getMessagesResp.Messages, 1) assert.Equal(t, expectedCodeTeamMessage.MessageId.Value, getMessagesResp.Messages[0].Cursor.Value) getMessagesResp.Messages[0].Cursor = nil - expectedCodeTeamMessage.Content[0].GetLocalized().KeyOrText = "localized message body content" + expectedCodeTeamMessage.Content[0].GetServerLocalized().KeyOrText = "localized message body content" assert.True(t, proto.Equal(expectedCodeTeamMessage, getMessagesResp.Messages[0])) getMessagesResp, err = env.client.GetMessages(env.ctx, getCashTransactionsMessagesReq) @@ -288,8 +288,8 @@ func TestChatHistoryReadState_HappyPath(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -346,8 +346,8 @@ func TestChatHistoryReadState_NegativeProgress(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: fmt.Sprintf("msg.body.key%d", i), }, }, @@ -429,8 +429,8 @@ func TestChatHistoryReadState_MessageNotFound(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, @@ -743,8 +743,8 @@ func TestUnauthorizedAccess(t *testing.T) { Ts: timestamppb.Now(), Content: []*chatpb.Content{ { - Type: &chatpb.Content_Localized{ - Localized: &chatpb.LocalizedContent{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ KeyOrText: "msg.body.key", }, }, diff --git a/pkg/code/server/grpc/messaging/server.go b/pkg/code/server/grpc/messaging/server.go index c89f69ce..85308c09 100644 --- a/pkg/code/server/grpc/messaging/server.go +++ b/pkg/code/server/grpc/messaging/server.go @@ -17,6 +17,7 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" "github.com/code-payments/code-server/pkg/cache" @@ -285,7 +286,7 @@ func (s *server) OpenMessageStreamWithKeepAlive(streamer messagingpb.Messaging_O err := streamer.Send(&messagingpb.OpenMessageStreamWithKeepAliveResponse{ ResponseOrPing: &messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping{ - Ping: &messagingpb.ServerPing{ + Ping: &commonpb.ServerPing{ Timestamp: timestamppb.Now(), PingDelay: durationpb.New(messageStreamPingDelay), }, diff --git a/pkg/code/server/grpc/messaging/testutil.go b/pkg/code/server/grpc/messaging/testutil.go index f1e7cf97..4968ab2a 100644 --- a/pkg/code/server/grpc/messaging/testutil.go +++ b/pkg/code/server/grpc/messaging/testutil.go @@ -373,7 +373,7 @@ func (c *clientEnv) receiveMessagesInRealTime(t *testing.T, rendezvousKey *commo case *messagingpb.OpenMessageStreamWithKeepAliveResponse_Ping: require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, @@ -467,7 +467,7 @@ func (c *clientEnv) waitUntilStreamTerminationOrTimeout(t *testing.T, rendezvous if keepStreamAlive { require.NoError(t, streamer.streamWithKeepAlives.Send(&messagingpb.OpenMessageStreamWithKeepAliveRequest{ RequestOrPong: &messagingpb.OpenMessageStreamWithKeepAliveRequest_Pong{ - Pong: &messagingpb.ClientPong{ + Pong: &commonpb.ClientPong{ Timestamp: timestamppb.Now(), }, }, diff --git a/pkg/code/server/grpc/transaction/v2/swap.go b/pkg/code/server/grpc/transaction/v2/swap.go index bfc76556..4f05ab2c 100644 --- a/pkg/code/server/grpc/transaction/v2/swap.go +++ b/pkg/code/server/grpc/transaction/v2/swap.go @@ -521,8 +521,8 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con } switch typed := protoChatMessage.Content[0].Type.(type) { - case *chatpb.Content_Localized: - if typed.Localized.KeyOrText != localization.ChatMessageUsdcDeposited { + case *chatpb.Content_ServerLocalized: + if typed.ServerLocalized.KeyOrText != localization.ChatMessageUsdcDeposited { return nil } } From c55e811d7265981ce39c13c8640f39561fdfc603 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 15 May 2024 16:09:26 -0400 Subject: [PATCH 06/71] PoC in memory two way messaging --- pkg/code/server/grpc/chat/server.go | 251 ++++++++++++++++++++++- pkg/code/server/grpc/chat/stream.go | 121 +++++++++++ pkg/code/server/grpc/messaging/server.go | 7 +- 3 files changed, 364 insertions(+), 15 deletions(-) create mode 100644 pkg/code/server/grpc/chat/stream.go diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index 17c201ca..b66ac9fe 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -1,14 +1,20 @@ package chat import ( + "bytes" "context" + "fmt" "math" + "strings" + "sync" + "time" "github.com/mr-tron/base58" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" @@ -28,19 +34,28 @@ const ( maxPageSize = 100 ) +var ( + mockTwoWayChat = chat.GetChatId("user1", "user2", true).ToProto() +) + +// todo: Resolve duplication of streaming logic with messaging service. The latest and greatest will live here. type server struct { log *logrus.Entry data code_data.Provider auth *auth_util.RPCSignatureVerifier + streamsMu sync.RWMutex + streams map[string]*chatEventStream + chatpb.UnimplementedChatServer } func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { return &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), - data: data, - auth: auth, + log: logrus.StandardLogger().WithField("type", "chat/server"), + data: data, + auth: auth, + streams: make(map[string]*chatEventStream), } } @@ -347,23 +362,54 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId := chat.ChatIdFromProto(req.ChatId) - log = log.WithField("chat_id", chatId.String()) + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + chatId := chat.ChatIdFromProto(req.ChatId) messageId := base58.Encode(req.Pointer.Value.Value) log = log.WithFields(logrus.Fields{ + "chat_id": chatId.String(), "message_id": messageId, "pointer_type": req.Pointer.Kind, }) - if req.Pointer.Kind != chatpb.Pointer_READ { - return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") + // todo: Temporary code to simluate real-time + if req.Pointer.User != nil { + return nil, status.Error(codes.InvalidArgument, "pointer.user cannot be set by clients") } + if bytes.Equal(mockTwoWayChat.Value, req.ChatId.Value) { + req.Pointer.User = &chatpb.ChatMemberId{Value: req.Owner.Value} - signature := req.Signature - req.Signature = nil - if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { - return nil, err + event := &chatpb.ChatStreamEvent{ + Pointers: []*chatpb.Pointer{req.Pointer}, + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil + } + + if req.Pointer.Kind != chatpb.Pointer_READ { + return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } chatRecord, err := s.data.GetChatById(ctx, chatId) @@ -529,3 +575,186 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr Result: chatpb.SetSubscriptionStateResponse_OK, }, nil } + +// +// Experimental PoC two-way chat APIs below +// + +func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { + ctx := streamer.Context() + + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) + if err != nil { + return err + } + + if req.GetOpenStream() == nil { + return status.Error(codes.InvalidArgument, "open_stream is nil") + } + + if req.GetOpenStream().Signature == nil { + return status.Error(codes.InvalidArgument, "signature is nil") + } + + if !bytes.Equal(req.GetOpenStream().ChatId.Value, mockTwoWayChat.Value) { + return status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.GetOpenStream().ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.GetOpenStream().Signature + req.GetOpenStream().Signature = nil + if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), owner.PublicKey().ToBase58()) + + s.streamsMu.Lock() + + stream, exists := s.streams[streamKey] + if exists { + s.streamsMu.Unlock() + // There's an existing stream on this server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + stream = newChatEventStream(streamBufferSize) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = stream + + s.streamsMu.Unlock() + + sendPingCh := time.After(0) + streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + + for { + select { + case event, ok := <-stream.streamCh: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: &chatpb.ChatStreamEventBatch{ + Events: []*chatpb.ChatStreamEvent{event}, + }, + }, + }) + if err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + if !bytes.Equal(req.ChatId.Value, mockTwoWayChat.Value) { + return nil, status.Error(codes.Unimplemented, "") + } + chatId := chat.ChatIdFromProto(req.ChatId) + log = log.WithField("chat_id", chatId.String()) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + switch req.Content[0].Type.(type) { + case *chatpb.Content_UserText: + default: + return nil, status.Error(codes.InvalidArgument, "content[0] must be UserText") + } + + // todo: Revisit message IDs + messageId, err := common.NewRandomAccount() + if err != nil { + log.WithError(err).Warn("failure generating random message id") + return nil, status.Error(codes.Internal, "") + } + + chatMessage := &chatpb.ChatMessage{ + MessageId: &chatpb.ChatMessageId{Value: messageId.ToProto().Value}, + Ts: timestamppb.Now(), + Content: req.Content, + Sender: &chatpb.ChatMemberId{Value: req.Owner.Value}, + Cursor: nil, // todo: Don't have cursor until we save it to the DB + } + + event := &chatpb.ChatStreamEvent{ + Messages: []*chatpb.ChatMessage{chatMessage}, + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go new file mode 100644 index 00000000..9d79969b --- /dev/null +++ b/pkg/code/server/grpc/chat/stream.go @@ -0,0 +1,121 @@ +package chat + +import ( + "context" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = 10 * time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/messaging/server.go b/pkg/code/server/grpc/messaging/server.go index 85308c09..87a35c3e 100644 --- a/pkg/code/server/grpc/messaging/server.go +++ b/pkg/code/server/grpc/messaging/server.go @@ -21,16 +21,15 @@ import ( messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" "github.com/code-payments/code-server/pkg/cache" - "github.com/code-payments/code-server/pkg/grpc/client" - "github.com/code-payments/code-server/pkg/retry" - "github.com/code-payments/code-server/pkg/retry/backoff" - "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/messaging" "github.com/code-payments/code-server/pkg/code/data/rendezvous" "github.com/code-payments/code-server/pkg/code/thirdparty" + "github.com/code-payments/code-server/pkg/grpc/client" + "github.com/code-payments/code-server/pkg/retry" + "github.com/code-payments/code-server/pkg/retry/backoff" ) const ( From 081cbf31897c7ae8de6faa7b31608e8c752f6b30 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 17 May 2024 11:44:02 -0400 Subject: [PATCH 07/71] Move chat event stream notification into an async worker --- pkg/code/server/grpc/chat/server.go | 61 +++++++++++++---------------- pkg/code/server/grpc/chat/stream.go | 58 ++++++++++++++++++++++++++- 2 files changed, 84 insertions(+), 35 deletions(-) diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index b66ac9fe..d76dc50d 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -5,7 +5,6 @@ import ( "context" "fmt" "math" - "strings" "sync" "time" @@ -28,6 +27,7 @@ import ( "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" + sync_util "github.com/code-payments/code-server/pkg/sync" ) const ( @@ -47,16 +47,27 @@ type server struct { streamsMu sync.RWMutex streams map[string]*chatEventStream + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel + chatpb.UnimplementedChatServer } func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - return &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), - data: data, - auth: auth, - streams: make(map[string]*chatEventStream), + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/server"), + data: data, + auth: auth, + streams: make(map[string]*chatEventStream), + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters + } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) } + + return s } func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { @@ -387,21 +398,9 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR Pointers: []*chatpb.Pointer{req.Pointer}, } - s.streamsMu.RLock() - for key, stream := range s.streams { - if !strings.HasPrefix(key, chatId.String()) { - continue - } - - if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { - continue - } - - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - } + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - s.streamsMu.RUnlock() return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_OK, @@ -718,6 +717,10 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, status.Error(codes.InvalidArgument, "content[0] must be UserText") } + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + // todo: Revisit message IDs messageId, err := common.NewRandomAccount() if err != nil { @@ -733,25 +736,15 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest Cursor: nil, // todo: Don't have cursor until we save it to the DB } + // todo: Save the message to the DB + event := &chatpb.ChatStreamEvent{ Messages: []*chatpb.ChatMessage{chatMessage}, } - s.streamsMu.RLock() - for key, stream := range s.streams { - if !strings.HasPrefix(key, chatId.String()) { - continue - } - - if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { - continue - } - - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - } + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - s.streamsMu.RUnlock() return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go index 9d79969b..d29bf7c8 100644 --- a/pkg/code/server/grpc/chat/stream.go +++ b/pkg/code/server/grpc/chat/stream.go @@ -2,6 +2,7 @@ package chat import ( "context" + "strings" "sync" "time" @@ -12,6 +13,8 @@ import ( "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + "github.com/code-payments/code-server/pkg/code/common" + "github.com/code-payments/code-server/pkg/code/data/chat" ) const ( @@ -19,7 +22,7 @@ const ( streamBufferSize = 64 streamPingDelay = 5 * time.Second streamKeepAliveRecvTimeout = 10 * time.Second - streamNotifyTimeout = 10 * time.Second + streamNotifyTimeout = time.Second ) type chatEventStream struct { @@ -90,6 +93,59 @@ func boundedStreamChatEventsRecv( } } +type chatIdWithEvent struct { + chatId chat.ChatId + owner *common.Account + event *chatpb.ChatStreamEvent + ts time.Time +} + +func (s *server) asyncNotifyAll(chatId chat.ChatId, owner *common.Account, event *chatpb.ChatStreamEvent) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + ok := s.chatEventChans.Send(chatId[:], &chatIdWithEvent{chatId, owner, m, time.Now()}) + if !ok { + return errors.New("chat event channel is full") + } + return nil +} + +func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatIdWithEvent) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log := log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("") + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, typedValue.chatId.String()) { + continue + } + + if strings.HasSuffix(key, typedValue.owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} + // Very naive implementation to start func monitorChatEventStreamHealth( ctx context.Context, From f112a51ff3c594e183c08c0bb3a36a6dc1cc6c41 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 17 May 2024 15:53:03 -0400 Subject: [PATCH 08/71] Ensure chat event streams are cleaned up after being closed --- pkg/code/server/grpc/chat/server.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index d76dc50d..2fe6d33f 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -639,6 +639,22 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e s.streamsMu.Unlock() + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another OpenMessageStream() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) From 9a8c990e2fbefc8537264f0702ba549663d98ce2 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 21 May 2024 14:11:55 -0400 Subject: [PATCH 09/71] Add support for thank you messages --- pkg/code/server/grpc/chat/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index 2fe6d33f..9614c124 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -728,9 +728,9 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } switch req.Content[0].Type.(type) { - case *chatpb.Content_UserText: + case *chatpb.Content_Text, *chatpb.Content_ThankYou: default: - return nil, status.Error(codes.InvalidArgument, "content[0] must be UserText") + return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") } chatLock := s.chatLocks.Get(chatId[:]) From 6db76d30b9cd021fb5dece8aa13441045bd30ace Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 21 May 2024 16:22:15 -0400 Subject: [PATCH 10/71] Rename struct --- pkg/code/server/grpc/chat/stream.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go index d29bf7c8..992970fb 100644 --- a/pkg/code/server/grpc/chat/stream.go +++ b/pkg/code/server/grpc/chat/stream.go @@ -93,7 +93,7 @@ func boundedStreamChatEventsRecv( } } -type chatIdWithEvent struct { +type chatEventNotification struct { chatId chat.ChatId owner *common.Account event *chatpb.ChatStreamEvent @@ -102,7 +102,7 @@ type chatIdWithEvent struct { func (s *server) asyncNotifyAll(chatId chat.ChatId, owner *common.Account, event *chatpb.ChatStreamEvent) error { m := proto.Clone(event).(*chatpb.ChatStreamEvent) - ok := s.chatEventChans.Send(chatId[:], &chatIdWithEvent{chatId, owner, m, time.Now()}) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, owner, m, time.Now()}) if !ok { return errors.New("chat event channel is full") } @@ -116,7 +116,7 @@ func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan inter }) for value := range channel { - typedValue, ok := value.(*chatIdWithEvent) + typedValue, ok := value.(*chatEventNotification) if !ok { log.Warn("channel did not receive expected struct") continue From 9bbf3ac10fdba46abf7c3e488e2e321966b4d50f Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 21 May 2024 16:36:49 -0400 Subject: [PATCH 11/71] Fill out missing log message --- pkg/code/server/grpc/chat/stream.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go index 992970fb..3f6ca6fa 100644 --- a/pkg/code/server/grpc/chat/stream.go +++ b/pkg/code/server/grpc/chat/stream.go @@ -125,7 +125,7 @@ func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan inter log := log.WithField("chat_id", typedValue.chatId.String()) if time.Since(typedValue.ts) > time.Second { - log.Warn("") + log.Warn("channel notification latency is elevated") } s.streamsMu.RLock() From 39b257e5be74950353d1ab68fd93045bf091ba34 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 22 May 2024 10:32:33 -0400 Subject: [PATCH 12/71] Add basic push support for user messages --- pkg/code/push/notifications.go | 11 +++++- pkg/code/server/grpc/chat/server.go | 47 ++++++++++++++++++++++-- pkg/code/server/grpc/chat/server_test.go | 3 +- 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index 42aff9be..2635dddd 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -364,8 +364,17 @@ func SendChatMessagePushNotification( }, }, } - case *chatpb.Content_NaclBox: + case *chatpb.Content_NaclBox, *chatpb.Content_Text: contentToPush = content + case *chatpb.Content_ThankYou: + contentToPush = &chatpb.Content{ + Type: &chatpb.Content_ServerLocalized{ + ServerLocalized: &chatpb.ServerLocalizedContent{ + // todo: localize this + KeyOrText: "🙏 They thanked you for their tip", + }, + }, + } } if contentToPush == nil { diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index 9614c124..d27673de 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -5,6 +5,7 @@ import ( "context" "fmt" "math" + "strings" "sync" "time" @@ -25,8 +26,10 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/chat" "github.com/code-payments/code-server/pkg/code/localization" + push_util "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" + push_lib "github.com/code-payments/code-server/pkg/push" sync_util "github.com/code-payments/code-server/pkg/sync" ) @@ -40,9 +43,10 @@ var ( // todo: Resolve duplication of streaming logic with messaging service. The latest and greatest will live here. type server struct { - log *logrus.Entry - data code_data.Provider - auth *auth_util.RPCSignatureVerifier + log *logrus.Entry + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + pusher push_lib.Provider streamsMu sync.RWMutex streams map[string]*chatEventStream @@ -53,11 +57,12 @@ type server struct { chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { s := &server{ log: logrus.StandardLogger().WithField("type", "chat/server"), data: data, auth: auth, + pusher: pusher, streams: make(map[string]*chatEventStream), chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters @@ -762,8 +767,42 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest log.WithError(err).Warn("failure notifying chat event") } + s.asyncPushChatMessage(ctx, owner, chatId, chatMessage) + return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, Message: chatMessage, }, nil } + +// todo: doesn't respect mute/unsubscribe rules +// todo: only sends pushes to active stream listeners instead of all message recipients +func (s *server) asyncPushChatMessage(ctx context.Context, sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + go func() { + s.streamsMu.RLock() + for key := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + receiver, err := common.NewAccountFromPublicKeyString(strings.Split(key, ":")[1]) + if err != nil { + continue + } + + if bytes.Equal(sender.PublicKey().ToBytes(), receiver.PublicKey().ToBytes()) { + continue + } + + go push_util.SendChatMessagePushNotification( + ctx, + s.data, + s.pusher, + "TontonTwitch", + receiver, + chatMessage, + ) + } + s.streamsMu.RUnlock() + }() +} diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/server_test.go index 627764b0..54af7c84 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/server_test.go @@ -29,6 +29,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/user/storage" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/kin" + memory_push "github.com/code-payments/code-server/pkg/push/memory" "github.com/code-payments/code-server/pkg/testutil" ) @@ -880,7 +881,7 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { data: code_data.NewTestDataProvider(), } - s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data)) + s := NewChatServer(env.data, auth_util.NewRPCSignatureVerifier(env.data), memory_push.NewPushProvider()) env.server = s.(*server) serv.RegisterService(func(server *grpc.Server) { From 540f868a167769b3ba5bddf7d07057f2317ed0cd Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 22 May 2024 10:39:59 -0400 Subject: [PATCH 13/71] Use separate context for pushing user chat messages --- pkg/code/server/grpc/chat/server.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/server.go index d27673de..b2b9c58f 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/server.go @@ -767,7 +767,7 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest log.WithError(err).Warn("failure notifying chat event") } - s.asyncPushChatMessage(ctx, owner, chatId, chatMessage) + s.asyncPushChatMessage(owner, chatId, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, @@ -777,7 +777,9 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest // todo: doesn't respect mute/unsubscribe rules // todo: only sends pushes to active stream listeners instead of all message recipients -func (s *server) asyncPushChatMessage(ctx context.Context, sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { +func (s *server) asyncPushChatMessage(sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + ctx := context.TODO() + go func() { s.streamsMu.RLock() for key := range s.streams { From 8185b9ede3cc26f210e0693e1cab17670e42adaf Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 10:12:26 -0400 Subject: [PATCH 14/71] Move existing chat stuff to v1 in prep for v2 --- pkg/code/async/geyser/external_deposit.go | 6 +- pkg/code/async/geyser/messenger.go | 6 +- pkg/code/chat/message_cash_transactions.go | 10 +-- pkg/code/chat/message_code_team.go | 4 +- pkg/code/chat/message_kin_purchases.go | 8 +- pkg/code/chat/message_merchant.go | 14 ++-- pkg/code/chat/message_tips.go | 6 +- pkg/code/chat/sender.go | 20 ++--- pkg/code/chat/sender_test.go | 54 ++++++------ pkg/code/data/chat/{ => v1}/memory/store.go | 2 +- .../data/chat/{ => v1}/memory/store_test.go | 2 +- pkg/code/data/chat/{ => v1}/model.go | 2 +- pkg/code/data/chat/{ => v1}/model_test.go | 2 +- pkg/code/data/chat/{ => v1}/postgres/model.go | 2 +- pkg/code/data/chat/{ => v1}/postgres/store.go | 2 +- .../data/chat/{ => v1}/postgres/store_test.go | 4 +- pkg/code/data/chat/{ => v1}/store.go | 2 +- pkg/code/data/chat/{ => v1}/tests/tests.go | 2 +- pkg/code/data/internal.go | 82 +++++++++---------- pkg/code/push/notifications.go | 10 +-- pkg/code/server/grpc/chat/{ => v1}/server.go | 30 +++---- .../server/grpc/chat/{ => v1}/server_test.go | 4 +- pkg/code/server/grpc/chat/{ => v1}/stream.go | 4 +- .../grpc/transaction/v2/history_test.go | 16 ++-- pkg/code/server/grpc/transaction/v2/swap.go | 6 +- .../server/grpc/transaction/v2/testutil.go | 4 +- 26 files changed, 152 insertions(+), 152 deletions(-) rename pkg/code/data/chat/{ => v1}/memory/store.go (99%) rename pkg/code/data/chat/{ => v1}/memory/store_test.go (74%) rename pkg/code/data/chat/{ => v1}/model.go (99%) rename pkg/code/data/chat/{ => v1}/model_test.go (97%) rename pkg/code/data/chat/{ => v1}/postgres/model.go (99%) rename pkg/code/data/chat/{ => v1}/postgres/store.go (98%) rename pkg/code/data/chat/{ => v1}/postgres/store_test.go (95%) rename pkg/code/data/chat/{ => v1}/store.go (99%) rename pkg/code/data/chat/{ => v1}/tests/tests.go (99%) rename pkg/code/server/grpc/chat/{ => v1}/server.go (96%) rename pkg/code/server/grpc/chat/{ => v1}/server_test.go (99%) rename pkg/code/server/grpc/chat/{ => v1}/stream.go (97%) diff --git a/pkg/code/async/geyser/external_deposit.go b/pkg/code/async/geyser/external_deposit.go index 1b5939ac..48685818 100644 --- a/pkg/code/async/geyser/external_deposit.go +++ b/pkg/code/async/geyser/external_deposit.go @@ -20,7 +20,7 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/deposit" "github.com/code-payments/code-server/pkg/code/data/fulfillment" "github.com/code-payments/code-server/pkg/code/data/intent" @@ -299,7 +299,7 @@ func processPotentialExternalDeposit(ctx context.Context, conf *conf, data code_ chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return errors.Wrap(err, "error sending chat message") } @@ -772,7 +772,7 @@ func delayedUsdcDepositProcessing( chatMessage, ) } - case chat.ErrMessageAlreadyExists: + case chat_v1.ErrMessageAlreadyExists: default: return } diff --git a/pkg/code/async/geyser/messenger.go b/pkg/code/async/geyser/messenger.go index 5c94153c..afb8e999 100644 --- a/pkg/code/async/geyser/messenger.go +++ b/pkg/code/async/geyser/messenger.go @@ -15,7 +15,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/code/thirdparty" "github.com/code-payments/code-server/pkg/database/query" @@ -169,13 +169,13 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi ctx, data, asciiBaseDomain, - chat.ChatTypeExternalApp, + chat_v1.ChatTypeExternalApp, true, recipientOwner, chatMessage, false, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/message_cash_transactions.go b/pkg/code/chat/message_cash_transactions.go index 1976d9d5..8a94361f 100644 --- a/pkg/code/chat/message_cash_transactions.go +++ b/pkg/code/chat/message_cash_transactions.go @@ -11,7 +11,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -93,9 +93,9 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro return errors.Wrap(err, "error getting original gift card issued intent") } - chatId := chat.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) + chatId := chat_v1.GetChatId(CashTransactionsName, giftCardIssuedIntentRecord.InitiatorOwnerAccount, true) - err = data.DeleteChatMessage(ctx, chatId, giftCardIssuedIntentRecord.IntentId) + err = data.DeleteChatMessageV1(ctx, chatId, giftCardIssuedIntentRecord.IntentId) if err != nil { return errors.Wrap(err, "error deleting chat message") } @@ -152,13 +152,13 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro ctx, data, CashTransactionsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, protoMessage, true, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return errors.Wrap(err, "error persisting chat message") } } diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index 536370a3..d24e2305 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -9,7 +9,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/localization" ) @@ -20,7 +20,7 @@ func SendCodeTeamMessage(ctx context.Context, data code_data.Provider, receiver ctx, data, CodeTeamName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index f9a2c6fd..4377c247 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -11,14 +11,14 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" ) // GetKinPurchasesChatId returns the chat ID for the Kin Purchases chat for a // given owner account -func GetKinPurchasesChatId(owner *common.Account) chat.ChatId { - return chat.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) +func GetKinPurchasesChatId(owner *common.Account) chat_v1.ChatId { + return chat_v1.GetChatId(KinPurchasesName, owner.PublicKey().ToBase58(), true) } // SendKinPurchasesMessage sends a message to the Kin Purchases chat. @@ -27,7 +27,7 @@ func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, recei ctx, data, KinPurchasesName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, chatMessage, diff --git a/pkg/code/chat/message_merchant.go b/pkg/code/chat/message_merchant.go index b4504c39..a340f8bd 100644 --- a/pkg/code/chat/message_merchant.go +++ b/pkg/code/chat/message_merchant.go @@ -13,7 +13,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -36,7 +36,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // the merchant. Representation in the UI may differ (ie. 2 and 3 are grouped), // but this is the most flexible solution with the chat model. chatTitle := PaymentsName - chatType := chat.ChatTypeInternal + chatType := chat_v1.ChatTypeInternal isVerifiedChat := false exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -59,7 +59,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i if paymentRequestRecord.Domain != nil { chatTitle = *paymentRequestRecord.Domain - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = paymentRequestRecord.IsVerified } @@ -87,7 +87,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -107,7 +107,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.SendPublicPaymentMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -126,7 +126,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i // and will have merchant payments appear in the verified merchant // chat. chatTitle = *destinationAccountInfoRecord.RelationshipTo - chatType = chat.ChatTypeExternalApp + chatType = chat_v1.ChatTypeExternalApp isVerifiedChat = true verbAndExchangeDataByMessageReceiver[intentRecord.ExternalDepositMetadata.DestinationOwnerAccount] = &verbAndExchangeData{ verb: chatpb.ExchangeDataContent_RECEIVED, @@ -171,7 +171,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i protoMessage, verbAndExchangeData.verb != chatpb.ExchangeDataContent_RECEIVED || !isVerifiedChat, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return nil, errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index b9984a9b..5752205a 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -9,7 +9,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/intent" ) @@ -70,13 +70,13 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten ctx, data, TipsName, - chat.ChatTypeInternal, + chat_v1.ChatTypeInternal, true, receiver, protoMessage, verb != chatpb.ExchangeDataContent_RECEIVED_TIP, ) - if err != nil && err != chat.ErrMessageAlreadyExists { + if err != nil && err != chat_v1.ErrMessageAlreadyExists { return nil, errors.Wrap(err, "error persisting chat message") } diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 41da0902..412e656b 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -12,7 +12,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" ) // SendChatMessage sends a chat message to a receiving owner account. @@ -24,13 +24,13 @@ func SendChatMessage( ctx context.Context, data code_data.Provider, chatTitle string, - chatType chat.ChatType, + chatType chat_v1.ChatType, isVerifiedChat bool, receiver *common.Account, protoMessage *chatpb.ChatMessage, isSilentMessage bool, ) (canPushMessage bool, err error) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) if protoMessage.Cursor != nil { // Let the utilities and GetMessages RPC handle cursors @@ -58,13 +58,13 @@ func SendChatMessage( canPersistMessage := true canPushMessage = !isSilentMessage - existingChatRecord, err := data.GetChatById(ctx, chatId) + existingChatRecord, err := data.GetChatByIdV1(ctx, chatId) switch err { case nil: canPersistMessage = !existingChatRecord.IsUnsubscribed canPushMessage = canPushMessage && canPersistMessage && !existingChatRecord.IsMuted - case chat.ErrChatNotFound: - chatRecord := &chat.Chat{ + case chat_v1.ErrChatNotFound: + chatRecord := &chat_v1.Chat{ ChatId: chatId, ChatType: chatType, ChatTitle: chatTitle, @@ -79,8 +79,8 @@ func SendChatMessage( CreatedAt: time.Now(), } - err = data.PutChat(ctx, chatRecord) - if err != nil && err != chat.ErrChatAlreadyExists { + err = data.PutChatV1(ctx, chatRecord) + if err != nil && err != chat_v1.ErrChatAlreadyExists { return false, err } default: @@ -88,7 +88,7 @@ func SendChatMessage( } if canPersistMessage { - messageRecord := &chat.Message{ + messageRecord := &chat_v1.Message{ ChatId: chatId, MessageId: base58.Encode(messageId), @@ -100,7 +100,7 @@ func SendChatMessage( Timestamp: ts.AsTime(), } - err = data.PutChatMessage(ctx, messageRecord) + err = data.PutChatMessageV1(ctx, messageRecord) if err != nil { return false, err } diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index 3438b3c8..ac767fd9 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -17,7 +17,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/badgecount" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/testutil" ) @@ -26,14 +26,14 @@ func TestSendChatMessage_HappyPath(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) var expectedBadgeCount int for i := 0; i < 10; i++ { chatMessage := newRandomChatMessage(t, i+1) expectedBadgeCount += 1 - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.True(t, canPush) @@ -56,7 +56,7 @@ func TestSendChatMessage_VerifiedChat(t *testing.T) { for _, isVerified := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, isVerified, receiver, chatMessage, true) + _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, isVerified, receiver, chatMessage, true) require.NoError(t, err) env.assertChatRecordSaved(t, chatTitle, receiver, isVerified) } @@ -67,11 +67,11 @@ func TestSendChatMessage_SilentMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for i, isSilent := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, isSilent) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, isSilent) require.NoError(t, err) assert.Equal(t, !isSilent, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, isSilent) @@ -84,7 +84,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isMuted := range []bool{false, true} { if isMuted { @@ -92,7 +92,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isMuted, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, false) @@ -105,7 +105,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) for _, isUnsubscribed := range []bool{false, true} { if isUnsubscribed { @@ -113,7 +113,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isUnsubscribed, canPush) if isUnsubscribed { @@ -130,12 +130,12 @@ func TestSendChatMessage_InvalidProtoMessage(t *testing.T) { chatTitle := CodeTeamName receiver := testutil.NewRandomAccount(t) - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), true) chatMessage := newRandomChatMessage(t, 1) chatMessage.Content = nil - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) assert.Error(t, err) assert.False(t, canPush) env.assertChatRecordNotSaved(t, chatId) @@ -173,13 +173,13 @@ func newRandomChatMessage(t *testing.T, contentLength int) *chatpb.ChatMessage { } func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver *common.Account, isVerified bool) { - chatId := chat.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) + chatId := chat_v1.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerified) - chatRecord, err := e.data.GetChatById(e.ctx, chatId) + chatRecord, err := e.data.GetChatByIdV1(e.ctx, chatId) require.NoError(t, err) assert.Equal(t, chatId[:], chatRecord.ChatId[:]) - assert.Equal(t, chat.ChatTypeInternal, chatRecord.ChatType) + assert.Equal(t, chat_v1.ChatTypeInternal, chatRecord.ChatType) assert.Equal(t, chatTitle, chatRecord.ChatTitle) assert.Equal(t, isVerified, chatRecord.IsVerified) assert.Equal(t, receiver.PublicKey().ToBase58(), chatRecord.CodeUser) @@ -188,8 +188,8 @@ func (e *testEnv) assertChatRecordSaved(t *testing.T, chatTitle string, receiver assert.False(t, chatRecord.IsUnsubscribed) } -func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { - messageRecord, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) +func (e *testEnv) assertChatMessageRecordSaved(t *testing.T, chatId chat_v1.ChatId, protoMessage *chatpb.ChatMessage, isSilent bool) { + messageRecord, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(protoMessage.GetMessageId().Value)) require.NoError(t, err) cloned := proto.Clone(protoMessage).(*chatpb.ChatMessage) @@ -218,22 +218,22 @@ func (e *testEnv) assertBadgeCount(t *testing.T, owner *common.Account, expected assert.EqualValues(t, expected, badgeCountRecord.BadgeCount) } -func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat.ChatId) { - _, err := e.data.GetChatById(e.ctx, chatId) - assert.Equal(t, chat.ErrChatNotFound, err) +func (e *testEnv) assertChatRecordNotSaved(t *testing.T, chatId chat_v1.ChatId) { + _, err := e.data.GetChatByIdV1(e.ctx, chatId) + assert.Equal(t, chat_v1.ErrChatNotFound, err) } -func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat.ChatId, messageId *chatpb.ChatMessageId) { - _, err := e.data.GetChatMessage(e.ctx, chatId, base58.Encode(messageId.Value)) - assert.Equal(t, chat.ErrMessageNotFound, err) +func (e *testEnv) assertChatMessageRecordNotSaved(t *testing.T, chatId chat_v1.ChatId, messageId *chatpb.ChatMessageId) { + _, err := e.data.GetChatMessageV1(e.ctx, chatId, base58.Encode(messageId.Value)) + assert.Equal(t, chat_v1.ErrMessageNotFound, err) } -func (e *testEnv) muteChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatMuteState(e.ctx, chatId, true)) +func (e *testEnv) muteChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatMuteStateV1(e.ctx, chatId, true)) } -func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat.ChatId) { - require.NoError(t, e.data.SetChatSubscriptionState(e.ctx, chatId, false)) +func (e *testEnv) unsubscribeFromChat(t *testing.T, chatId chat_v1.ChatId) { + require.NoError(t, e.data.SetChatSubscriptionStateV1(e.ctx, chatId, false)) } diff --git a/pkg/code/data/chat/memory/store.go b/pkg/code/data/chat/v1/memory/store.go similarity index 99% rename from pkg/code/data/chat/memory/store.go rename to pkg/code/data/chat/v1/memory/store.go index 1b869007..1d923071 100644 --- a/pkg/code/data/chat/memory/store.go +++ b/pkg/code/data/chat/v1/memory/store.go @@ -7,8 +7,8 @@ import ( "sync" "time" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type ChatsById []*chat.Chat diff --git a/pkg/code/data/chat/memory/store_test.go b/pkg/code/data/chat/v1/memory/store_test.go similarity index 74% rename from pkg/code/data/chat/memory/store_test.go rename to pkg/code/data/chat/v1/memory/store_test.go index 5d2c18a5..c27859e6 100644 --- a/pkg/code/data/chat/memory/store_test.go +++ b/pkg/code/data/chat/v1/memory/store_test.go @@ -3,7 +3,7 @@ package memory import ( "testing" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" ) func TestChatMemoryStore(t *testing.T) { diff --git a/pkg/code/data/chat/model.go b/pkg/code/data/chat/v1/model.go similarity index 99% rename from pkg/code/data/chat/model.go rename to pkg/code/data/chat/v1/model.go index d8fe7432..4d156996 100644 --- a/pkg/code/data/chat/model.go +++ b/pkg/code/data/chat/v1/model.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "bytes" diff --git a/pkg/code/data/chat/model_test.go b/pkg/code/data/chat/v1/model_test.go similarity index 97% rename from pkg/code/data/chat/model_test.go rename to pkg/code/data/chat/v1/model_test.go index 7774d286..062f372b 100644 --- a/pkg/code/data/chat/model_test.go +++ b/pkg/code/data/chat/v1/model_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "testing" diff --git a/pkg/code/data/chat/postgres/model.go b/pkg/code/data/chat/v1/postgres/model.go similarity index 99% rename from pkg/code/data/chat/postgres/model.go rename to pkg/code/data/chat/v1/postgres/model.go index 07158095..8987df2b 100644 --- a/pkg/code/data/chat/postgres/model.go +++ b/pkg/code/data/chat/v1/postgres/model.go @@ -8,10 +8,10 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" pgutil "github.com/code-payments/code-server/pkg/database/postgres" q "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) const ( diff --git a/pkg/code/data/chat/postgres/store.go b/pkg/code/data/chat/v1/postgres/store.go similarity index 98% rename from pkg/code/data/chat/postgres/store.go rename to pkg/code/data/chat/v1/postgres/store.go index 943a1935..bfb2b14f 100644 --- a/pkg/code/data/chat/postgres/store.go +++ b/pkg/code/data/chat/v1/postgres/store.go @@ -6,8 +6,8 @@ import ( "github.com/jmoiron/sqlx" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" - "github.com/code-payments/code-server/pkg/code/data/chat" ) type store struct { diff --git a/pkg/code/data/chat/postgres/store_test.go b/pkg/code/data/chat/v1/postgres/store_test.go similarity index 95% rename from pkg/code/data/chat/postgres/store_test.go rename to pkg/code/data/chat/v1/postgres/store_test.go index 49143ad7..4d72fc4e 100644 --- a/pkg/code/data/chat/postgres/store_test.go +++ b/pkg/code/data/chat/v1/postgres/store_test.go @@ -8,8 +8,8 @@ import ( "github.com/ory/dockertest/v3" "github.com/sirupsen/logrus" - "github.com/code-payments/code-server/pkg/code/data/chat" - "github.com/code-payments/code-server/pkg/code/data/chat/tests" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" + "github.com/code-payments/code-server/pkg/code/data/chat/v1/tests" postgrestest "github.com/code-payments/code-server/pkg/database/postgres/test" diff --git a/pkg/code/data/chat/store.go b/pkg/code/data/chat/v1/store.go similarity index 99% rename from pkg/code/data/chat/store.go rename to pkg/code/data/chat/v1/store.go index 2e79a228..c21471b5 100644 --- a/pkg/code/data/chat/store.go +++ b/pkg/code/data/chat/v1/store.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" diff --git a/pkg/code/data/chat/tests/tests.go b/pkg/code/data/chat/v1/tests/tests.go similarity index 99% rename from pkg/code/data/chat/tests/tests.go rename to pkg/code/data/chat/v1/tests/tests.go index f9eaaadf..1453c133 100644 --- a/pkg/code/data/chat/tests/tests.go +++ b/pkg/code/data/chat/v1/tests/tests.go @@ -10,9 +10,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/pointer" - "github.com/code-payments/code-server/pkg/code/data/chat" ) func RunTests(t *testing.T, s chat.Store, teardown func()) { diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 49479c84..95ce26f2 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -25,7 +25,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/airdrop" "github.com/code-payments/code-server/pkg/code/data/badgecount" "github.com/code-payments/code-server/pkg/code/data/balance" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/contact" "github.com/code-payments/code-server/pkg/code/data/currency" @@ -59,7 +59,7 @@ import ( airdrop_memory_client "github.com/code-payments/code-server/pkg/code/data/airdrop/memory" badgecount_memory_client "github.com/code-payments/code-server/pkg/code/data/badgecount/memory" balance_memory_client "github.com/code-payments/code-server/pkg/code/data/balance/memory" - chat_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/memory" + chat_v1_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/memory" commitment_memory_client "github.com/code-payments/code-server/pkg/code/data/commitment/memory" contact_memory_client "github.com/code-payments/code-server/pkg/code/data/contact/memory" currency_memory_client "github.com/code-payments/code-server/pkg/code/data/currency/memory" @@ -94,7 +94,7 @@ import ( airdrop_postgres_client "github.com/code-payments/code-server/pkg/code/data/airdrop/postgres" badgecount_postgres_client "github.com/code-payments/code-server/pkg/code/data/badgecount/postgres" balance_postgres_client "github.com/code-payments/code-server/pkg/code/data/balance/postgres" - chat_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/postgres" + chat_v1_postgres_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/postgres" commitment_postgres_client "github.com/code-payments/code-server/pkg/code/data/commitment/postgres" contact_postgres_client "github.com/code-payments/code-server/pkg/code/data/contact/postgres" currency_postgres_client "github.com/code-payments/code-server/pkg/code/data/currency/postgres" @@ -378,19 +378,19 @@ type DatabaseData interface { CountWebhookByState(ctx context.Context, state webhook.State) (uint64, error) GetAllPendingWebhooksReadyToSend(ctx context.Context, limit uint64) ([]*webhook.Record, error) - // Chat + // Chat V1 // -------------------------------------------------------------------------------- - PutChat(ctx context.Context, record *chat.Chat) error - GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) - GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) - PutChatMessage(ctx context.Context, record *chat.Message) error - DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error - GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) - GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) - AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error - GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) - SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error - SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error + PutChatV1(ctx context.Context, record *chat_v1.Chat) error + GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) + GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) + PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error + DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error + GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) + GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) + AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error + GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) + SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error + SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error // Badge Count // -------------------------------------------------------------------------------- @@ -470,7 +470,7 @@ type DatabaseProvider struct { paywall paywall.Store event event.Store webhook webhook.Store - chat chat.Store + chatv1 chat_v1.Store badgecount badgecount.Store login login.Store balance balance.Store @@ -532,7 +532,7 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { paywall: paywall_postgres_client.New(db), event: event_postgres_client.New(db), webhook: webhook_postgres_client.New(db), - chat: chat_postgres_client.New(db), + chatv1: chat_v1_postgres_client.New(db), badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), balance: balance_postgres_client.New(db), @@ -575,7 +575,7 @@ func NewTestDatabaseProvider() DatabaseData { paywall: paywall_memory_client.New(), event: event_memory_client.New(), webhook: webhook_memory_client.New(), - chat: chat_memory_client.New(), + chatv1: chat_v1_memory_client.New(), badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), balance: balance_memory_client.New(), @@ -1399,48 +1399,48 @@ func (dp *DatabaseProvider) GetAllPendingWebhooksReadyToSend(ctx context.Context return dp.webhook.GetAllPendingReadyToSend(ctx, limit) } -// Chat +// Chat V1 // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) PutChat(ctx context.Context, record *chat.Chat) error { - return dp.chat.PutChat(ctx, record) +func (dp *DatabaseProvider) PutChatV1(ctx context.Context, record *chat_v1.Chat) error { + return dp.chatv1.PutChat(ctx, record) } -func (dp *DatabaseProvider) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.Chat, error) { - return dp.chat.GetChatById(ctx, chatId) +func (dp *DatabaseProvider) GetChatByIdV1(ctx context.Context, chatId chat_v1.ChatId) (*chat_v1.Chat, error) { + return dp.chatv1.GetChatById(ctx, chatId) } -func (dp *DatabaseProvider) GetAllChatsForUser(ctx context.Context, user string, opts ...query.Option) ([]*chat.Chat, error) { +func (dp *DatabaseProvider) GetAllChatsForUserV1(ctx context.Context, user string, opts ...query.Option) ([]*chat_v1.Chat, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllChatsForUser(ctx, user, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) PutChatMessage(ctx context.Context, record *chat.Message) error { - return dp.chat.PutMessage(ctx, record) +func (dp *DatabaseProvider) PutChatMessageV1(ctx context.Context, record *chat_v1.Message) error { + return dp.chatv1.PutMessage(ctx, record) } -func (dp *DatabaseProvider) DeleteChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) error { - return dp.chat.DeleteMessage(ctx, chatId, messageId) +func (dp *DatabaseProvider) DeleteChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) error { + return dp.chatv1.DeleteMessage(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetChatMessage(ctx context.Context, chatId chat.ChatId, messageId string) (*chat.Message, error) { - return dp.chat.GetMessageById(ctx, chatId, messageId) +func (dp *DatabaseProvider) GetChatMessageV1(ctx context.Context, chatId chat_v1.ChatId, messageId string) (*chat_v1.Message, error) { + return dp.chatv1.GetMessageById(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetAllChatMessages(ctx context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.Message, error) { +func (dp *DatabaseProvider) GetAllChatMessagesV1(ctx context.Context, chatId chat_v1.ChatId, opts ...query.Option) ([]*chat_v1.Message, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chat.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv1.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) AdvanceChatPointer(ctx context.Context, chatId chat.ChatId, pointer string) error { - return dp.chat.AdvancePointer(ctx, chatId, pointer) +func (dp *DatabaseProvider) AdvanceChatPointerV1(ctx context.Context, chatId chat_v1.ChatId, pointer string) error { + return dp.chatv1.AdvancePointer(ctx, chatId, pointer) } -func (dp *DatabaseProvider) GetChatUnreadCount(ctx context.Context, chatId chat.ChatId) (uint32, error) { - return dp.chat.GetUnreadCount(ctx, chatId) +func (dp *DatabaseProvider) GetChatUnreadCountV1(ctx context.Context, chatId chat_v1.ChatId) (uint32, error) { + return dp.chatv1.GetUnreadCount(ctx, chatId) } -func (dp *DatabaseProvider) SetChatMuteState(ctx context.Context, chatId chat.ChatId, isMuted bool) error { - return dp.chat.SetMuteState(ctx, chatId, isMuted) +func (dp *DatabaseProvider) SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error { + return dp.chatv1.SetMuteState(ctx, chatId, isMuted) } -func (dp *DatabaseProvider) SetChatSubscriptionState(ctx context.Context, chatId chat.ChatId, isSubscribed bool) error { - return dp.chat.SetSubscriptionState(ctx, chatId, isSubscribed) +func (dp *DatabaseProvider) SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error { + return dp.chatv1.SetSubscriptionState(ctx, chatId, isSubscribed) } // Badge Count diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index 2635dddd..bfe4db7c 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -14,7 +14,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/code/thirdparty" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -59,13 +59,13 @@ func SendDepositPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") @@ -139,13 +139,13 @@ func SendGiftCardReturnedPushNotification( // Legacy push notification still considers chat mute state // // todo: Proper migration to chat system - chatRecord, err := data.GetChatById(ctx, chat.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) + chatRecord, err := data.GetChatByIdV1(ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, owner.PublicKey().ToBase58(), true)) switch err { case nil: if chatRecord.IsMuted { return nil } - case chat.ErrChatNotFound: + case chat_v1.ErrChatNotFound: default: log.WithError(err).Warn("failure getting chat record") return errors.Wrap(err, "error getting chat record") diff --git a/pkg/code/server/grpc/chat/server.go b/pkg/code/server/grpc/chat/v1/server.go similarity index 96% rename from pkg/code/server/grpc/chat/server.go rename to pkg/code/server/grpc/chat/v1/server.go index b2b9c58f..5fef09f9 100644 --- a/pkg/code/server/grpc/chat/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "bytes" @@ -24,7 +24,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" push_util "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/database/query" @@ -59,7 +59,7 @@ type server struct { func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { s := &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), + log: logrus.StandardLogger().WithField("type", "chat/v1/server"), data: data, auth: auth, pusher: pusher, @@ -119,7 +119,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } - chatRecords, err := s.data.GetAllChatsForUser( + chatRecords, err := s.data.GetAllChatsForUserV1( ctx, owner.PublicKey().ToBase58(), query.WithCursor(cursor), @@ -220,7 +220,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch if !skipUnreadCountQuery && !chatRecord.IsMuted && !chatRecord.IsUnsubscribed { // todo: will need batching when users have a large number of chats - unreadCount, err := s.data.GetChatUnreadCount(ctx, chatRecord.ChatId) + unreadCount, err := s.data.GetChatUnreadCountV1(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting unread count") } @@ -260,7 +260,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_NOT_FOUND, @@ -296,7 +296,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest cursor = req.Cursor.Value } - messageRecords, err := s.data.GetAllChatMessages( + messageRecords, err := s.data.GetAllChatMessagesV1( ctx, chatId, query.WithCursor(cursor), @@ -416,7 +416,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, @@ -430,7 +430,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.PermissionDenied, "") } - newPointerRecord, err := s.data.GetChatMessage(ctx, chatId, messageId) + newPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, messageId) if err == chat.ErrMessageNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, @@ -441,7 +441,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } if chatRecord.ReadPointer != nil { - oldPointerRecord, err := s.data.GetChatMessage(ctx, chatId, *chatRecord.ReadPointer) + oldPointerRecord, err := s.data.GetChatMessageV1(ctx, chatId, *chatRecord.ReadPointer) if err != nil { log.WithError(err).Warn("failure getting chat message record for old pointer value") return nil, status.Error(codes.Internal, "") @@ -454,7 +454,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } } - err = s.data.AdvanceChatPointer(ctx, chatId, messageId) + err = s.data.AdvanceChatPointerV1(ctx, chatId, messageId) if err != nil { log.WithError(err).Warn("failure advancing pointer") return nil, status.Error(codes.Internal, "") @@ -484,7 +484,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, @@ -511,7 +511,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque }, nil } - err = s.data.SetChatMuteState(ctx, chatId, req.IsMuted) + err = s.data.SetChatMuteStateV1(ctx, chatId, req.IsMuted) if err != nil { log.WithError(err).Warn("failure setting mute status") return nil, status.Error(codes.Internal, "") @@ -542,7 +542,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr return nil, err } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.SetSubscriptionStateResponse{ Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, @@ -569,7 +569,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } - err = s.data.SetChatSubscriptionState(ctx, chatId, req.IsSubscribed) + err = s.data.SetChatSubscriptionStateV1(ctx, chatId, req.IsSubscribed) if err != nil { log.WithError(err).Warn("failure setting subcription status") return nil, status.Error(codes.Internal, "") diff --git a/pkg/code/server/grpc/chat/server_test.go b/pkg/code/server/grpc/chat/v1/server_test.go similarity index 99% rename from pkg/code/server/grpc/chat/server_test.go rename to pkg/code/server/grpc/chat/v1/server_test.go index 54af7c84..0a9abad4 100644 --- a/pkg/code/server/grpc/chat/server_test.go +++ b/pkg/code/server/grpc/chat/v1/server_test.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" @@ -22,7 +22,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/phone" "github.com/code-payments/code-server/pkg/code/data/preferences" "github.com/code-payments/code-server/pkg/code/data/user" diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/v1/stream.go similarity index 97% rename from pkg/code/server/grpc/chat/stream.go rename to pkg/code/server/grpc/chat/v1/stream.go index 3f6ca6fa..05b63235 100644 --- a/pkg/code/server/grpc/chat/stream.go +++ b/pkg/code/server/grpc/chat/v1/stream.go @@ -1,4 +1,4 @@ -package chat +package chat_v1 import ( "context" @@ -14,7 +14,7 @@ import ( chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" "github.com/code-payments/code-server/pkg/code/common" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v1" ) const ( diff --git a/pkg/code/server/grpc/transaction/v2/history_test.go b/pkg/code/server/grpc/transaction/v2/history_test.go index f2cddee5..80ae0442 100644 --- a/pkg/code/server/grpc/transaction/v2/history_test.go +++ b/pkg/code/server/grpc/transaction/v2/history_test.go @@ -12,7 +12,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" @@ -142,7 +142,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { sendingPhone.tip456KinToCodeUser(t, receivingPhone, twitterUsername).requireSuccess(t) - chatMessageRecords, err := server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err := server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 10) @@ -236,7 +236,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 32.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(321), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 3) @@ -267,7 +267,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 123.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(123), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.CashTransactionsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 7) @@ -334,7 +334,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 2.1, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(42), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -347,7 +347,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 45.6, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(456), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 5) @@ -396,7 +396,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 12_345.0, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(12_345), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId("example.com", receivingPhone.parentAccount.PublicKey().ToBase58(), false)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) @@ -409,7 +409,7 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 77.69, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.EqualValues(t, 77690000, protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) - chatMessageRecords, err = server.data.GetAllChatMessages(server.ctx, chat.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + chatMessageRecords, err = server.data.GetAllChatMessagesV1(server.ctx, chat_v1.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) require.NoError(t, err) require.Len(t, chatMessageRecords, 1) diff --git a/pkg/code/server/grpc/transaction/v2/swap.go b/pkg/code/server/grpc/transaction/v2/swap.go index 4f05ab2c..8bb6e230 100644 --- a/pkg/code/server/grpc/transaction/v2/swap.go +++ b/pkg/code/server/grpc/transaction/v2/swap.go @@ -22,7 +22,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" "github.com/code-payments/code-server/pkg/code/data/account" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/localization" push_util "github.com/code-payments/code-server/pkg/code/push" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -511,7 +511,7 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con // Inspect the chat history for a USDC deposited message. If that message // doesn't exist, then avoid sending the swap in progress chat message, since // it can lead to user confusion. - chatMessageRecords, err := s.data.GetAllChatMessages(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) + chatMessageRecords, err := s.data.GetAllChatMessagesV1(ctx, chatId, query.WithDirection(query.Descending), query.WithLimit(1)) switch err { case nil: var protoChatMessage chatpb.ChatMessage @@ -526,7 +526,7 @@ func (s *transactionServer) bestEffortNotifyUserOfSwapInProgress(ctx context.Con return nil } } - case chat.ErrMessageNotFound: + case chat_v1.ErrMessageNotFound: default: return errors.Wrap(err, "error fetching chat messages") } diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 38c42626..1da54d65 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -35,7 +35,7 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" - "github.com/code-payments/code-server/pkg/code/data/chat" + chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/currency" "github.com/code-payments/code-server/pkg/code/data/deposit" @@ -6173,7 +6173,7 @@ func isSubmitIntentError(resp *transactionpb.SubmitIntentResponse, err error) bo return err != nil || resp.GetError() != nil } -func getProtoChatMessage(t *testing.T, record *chat.Message) *chatpb.ChatMessage { +func getProtoChatMessage(t *testing.T, record *chat_v1.Message) *chatpb.ChatMessage { var protoMessage chatpb.ChatMessage require.NoError(t, proto.Unmarshal(record.Data, &protoMessage)) return &protoMessage From 2ef1e0249f6f699bcc9ecf354b9cec7c9b2b3298 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 10:55:49 -0400 Subject: [PATCH 15/71] Add skeleton for v2 gRPC chat service --- pkg/code/server/grpc/chat/v2/server.go | 190 ++++++++++++++++++++ pkg/code/server/grpc/chat/v2/server_test.go | 1 + pkg/code/server/grpc/chat/v2/stream.go | 32 ++++ 3 files changed, 223 insertions(+) create mode 100644 pkg/code/server/grpc/chat/v2/server.go create mode 100644 pkg/code/server/grpc/chat/v2/server_test.go create mode 100644 pkg/code/server/grpc/chat/v2/stream.go diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go new file mode 100644 index 00000000..37319f40 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -0,0 +1,190 @@ +package chat_v2 + +import ( + "context" + "time" + + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + + auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/common" + code_data "github.com/code-payments/code-server/pkg/code/data" + "github.com/code-payments/code-server/pkg/grpc/client" +) + +// todo: Ensure all relevant logging fields are set +type server struct { + log *logrus.Entry + + data code_data.Provider + auth *auth_util.RPCSignatureVerifier + + chatpb.UnimplementedChatServer +} + +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { + return &server{ + log: logrus.StandardLogger().WithField("type", "chat/v2/server"), + data: data, + auth: auth, + } +} + +func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { + log := s.log.WithField("method", "GetChats") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} + +func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { + log := s.log.WithField("method", "GetMessages") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} + +func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { + ctx := streamer.Context() + + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) + if err != nil { + return err + } + + if req.GetOpenStream() == nil { + return status.Error(codes.InvalidArgument, "open_stream is nil") + } + + if req.GetOpenStream().Signature == nil { + return status.Error(codes.InvalidArgument, "signature is nil") + } + + owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.GetOpenStream().Signature + req.GetOpenStream().Signature = nil + if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err + } + + return status.Error(codes.Unimplemented, "") +} + +func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} + +func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { + log := s.log.WithField("method", "AdvancePointer") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} + +func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { + log := s.log.WithField("method", "SetMuteState") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} + +func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { + log := s.log.WithField("method", "SetSubscriptionState") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + return nil, status.Error(codes.Unimplemented, "") +} diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go new file mode 100644 index 00000000..aacc4f95 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -0,0 +1 @@ +package chat_v2 diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go new file mode 100644 index 00000000..bf986572 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -0,0 +1,32 @@ +package chat_v2 + +import ( + "context" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} From 23a9aa09e653aba9a29bbed81779c71e3ff668e9 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 12:16:22 -0400 Subject: [PATCH 16/71] Define chat v2 models --- pkg/code/data/chat/v2/id.go | 182 +++++++++++++++ pkg/code/data/chat/v2/id_test.go | 55 +++++ pkg/code/data/chat/v2/model.go | 370 +++++++++++++++++++++++++++++++ pkg/code/data/chat/v2/store.go | 9 + 4 files changed, 616 insertions(+) create mode 100644 pkg/code/data/chat/v2/id.go create mode 100644 pkg/code/data/chat/v2/id_test.go create mode 100644 pkg/code/data/chat/v2/model.go create mode 100644 pkg/code/data/chat/v2/store.go diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go new file mode 100644 index 00000000..267ae7f2 --- /dev/null +++ b/pkg/code/data/chat/v2/id.go @@ -0,0 +1,182 @@ +package chat_v2 + +import ( + "bytes" + "encoding/hex" + "time" + + "github.com/google/uuid" + "github.com/pkg/errors" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +type ChatId [32]byte + +// GetChatIdFromProto gets a chat ID from the protobuf variant +func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { + if err := proto.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "proto validation failed") + } + + var typed ChatId + copy(typed[:], proto.Value) + + if err := typed.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "invalid chat id") + } + + return typed, nil +} + +// Validate validates a chat ID +func (c ChatId) Validate() error { + return nil +} + +// String returns the string representation of a ChatId +func (c ChatId) String() string { + return hex.EncodeToString(c[:]) +} + +// Random UUIDv4 ID for chat members +type MemberId uuid.UUID + +// GetMemberIdFromProto gets a member ID from the protobuf variant +func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { + if err := proto.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "proto validation failed") + } + + var typed MemberId + copy(typed[:], proto.Value) + + if err := typed.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "invalid member id") + } + + return typed, nil +} + +// GenerateMemberId generates a new random chat member ID +func GenerateMemberId() MemberId { + return MemberId(uuid.New()) +} + +// Validate validates a chat member ID +func (m MemberId) Validate() error { + casted := uuid.UUID(m) + + if casted.Version() != 4 { + return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + } + + return nil +} + +// String returns the string representation of a MemberId +func (m MemberId) String() string { + return uuid.UUID(m).String() +} + +// Time-based UUIDv7 ID for chat messages +type MessageId uuid.UUID + +// GenerateMessageId generates a UUIDv7 message ID using the current time +func GenerateMessageId() MessageId { + return GenerateMessageIdAtTime(time.Now()) +} + +// GenerateMessageIdAtTime generates a UUIDv7 message ID using the provided timestamp +func GenerateMessageIdAtTime(ts time.Time) MessageId { + // Convert timestamp to milliseconds since Unix epoch + millis := ts.UnixNano() / int64(time.Millisecond) + + // Create a byte slice to hold the UUID + var uuidBytes [16]byte + + // Populate the first 6 bytes with the timestamp (42 bits for timestamp) + uuidBytes[0] = byte((millis >> 40) & 0xff) + uuidBytes[1] = byte((millis >> 32) & 0xff) + uuidBytes[2] = byte((millis >> 24) & 0xff) + uuidBytes[3] = byte((millis >> 16) & 0xff) + uuidBytes[4] = byte((millis >> 8) & 0xff) + uuidBytes[5] = byte(millis & 0xff) + + // Set the version to 7 (UUIDv7) + uuidBytes[6] = (uuidBytes[6] & 0x0f) | (0x7 << 4) + + // Populate the remaining bytes with random values + randomUUID := uuid.New() + copy(uuidBytes[7:], randomUUID[7:]) + + return MessageId(uuidBytes) +} + +// GetMessageIdFromProto gets a message ID from the protobuf variant +func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { + if err := proto.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "proto validation failed") + } + + var typed MessageId + copy(typed[:], proto.Value) + + if err := typed.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "invalid message id") + } + + return typed, nil +} + +// GetTimestamp gets the encoded timestamp in the message ID +func (m MessageId) GetTimestamp() (time.Time, error) { + if err := m.Validate(); err != nil { + return time.Time{}, errors.Wrap(err, "invalid message id") + } + + // Extract the first 6 bytes as the timestamp + millis := (int64(m[0]) << 40) | (int64(m[1]) << 32) | (int64(m[2]) << 24) | + (int64(m[3]) << 16) | (int64(m[4]) << 8) | int64(m[5]) + + // Convert milliseconds since Unix epoch to time.Time + timestamp := time.Unix(0, millis*int64(time.Millisecond)) + + return timestamp, nil +} + +// Equals returns whether two message IDs are equal +func (m MessageId) Equals(other MessageId) bool { + return m.Compare(other) == 0 +} + +// Before returns whether the message ID is before the provided value +func (m MessageId) Before(other MessageId) bool { + return m.Compare(other) < 0 +} + +// Before returns whether the message ID is after the provided value +func (m MessageId) After(other MessageId) bool { + return m.Compare(other) > 0 +} + +// Compare returns the byte comparison of the message ID +func (m MessageId) Compare(other MessageId) int { + return bytes.Compare(m[:], other[:]) +} + +// Validate validates a message ID +func (m MessageId) Validate() error { + casted := uuid.UUID(m) + + if casted.Version() != 7 { + return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + } + + return nil +} + +// String returns the string representation of a MessageId +func (m MessageId) String() string { + return uuid.UUID(m).String() +} diff --git a/pkg/code/data/chat/v2/id_test.go b/pkg/code/data/chat/v2/id_test.go new file mode 100644 index 00000000..27f1f359 --- /dev/null +++ b/pkg/code/data/chat/v2/id_test.go @@ -0,0 +1,55 @@ +package chat_v2 + +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGenerateMemberId_Validation(t *testing.T) { + valid := GenerateMemberId() + assert.NoError(t, valid.Validate()) + + invalid := MemberId(GenerateMessageId()) + assert.Error(t, invalid.Validate()) +} + +func TestGenerateMessageId_Validation(t *testing.T) { + valid := GenerateMessageId() + assert.NoError(t, valid.Validate()) + + invalid := MessageId(uuid.New()) + assert.Error(t, invalid.Validate()) +} + +func TestGenerateMessageId_TimestampExtraction(t *testing.T) { + expectedTs := time.Now() + + messageId := GenerateMessageIdAtTime(expectedTs) + actualTs, err := messageId.GetTimestamp() + require.NoError(t, err) + assert.Equal(t, expectedTs.UnixMilli(), actualTs.UnixMilli()) +} + +func TestGenerateMessageId_Ordering(t *testing.T) { + now := time.Now() + messageIds := make([]MessageId, 0) + for i := 0; i < 10; i++ { + messageId := GenerateMessageIdAtTime(now.Add(time.Duration(i * int(time.Millisecond)))) + messageIds = append(messageIds, messageId) + } + + for i := 0; i < len(messageIds)-1; i++ { + assert.True(t, messageIds[i].Equals(messageIds[i])) + assert.False(t, messageIds[i].Equals(messageIds[i+1])) + + assert.True(t, messageIds[i].Before(messageIds[i+1])) + assert.False(t, messageIds[i].After(messageIds[i+1])) + + assert.False(t, messageIds[i+1].Before(messageIds[i])) + assert.True(t, messageIds[i+1].After(messageIds[i])) + } +} diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go new file mode 100644 index 00000000..d68a76eb --- /dev/null +++ b/pkg/code/data/chat/v2/model.go @@ -0,0 +1,370 @@ +package chat_v2 + +import ( + "time" + + "github.com/mr-tron/base58" + "github.com/pkg/errors" + + "github.com/code-payments/code-server/pkg/pointer" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" +) + +type ChatType int + +const ( + ChatTypeUnknown ChatType = iota + ChatTypeNotification + ChatTypeTwoWay + // ChatTypeGroup +) + +type ReferenceType int + +const ( + ReferenceTypeUnknown ReferenceType = iota + ReferenceTypeIntent + ReferenceTypeSignature +) + +type PointerType int + +const ( + PointerTypeUnknown PointerType = iota + PointerTypeSent + PointerTypeDelivered + PointerTypeRead +) + +type Platform int + +const ( + PlatformUnknown Platform = iota + PlatformCode + PlatformTwitter +) + +type ChatRecord struct { + Id int64 + ChatId ChatId + + ChatType ChatType + + // Presence determined by ChatType: + // * Notification: Present, and may be a localization key + // * Two Way: Not present and generated dynamically based on chat members + // * Group: Present, and will not be a localization key + ChatTitle *string + + IsVerified bool + + CreatedAt time.Time +} + +type MemberRecord struct { + Id int64 + ChatId ChatId + MemberId MemberId + + Platform Platform + PlatformId string + + DeliveryPointer *MessageId + ReadPointer *MessageId + + IsMuted bool + IsUnsubscribed bool + + JoinedAt time.Time +} + +type MessageRecord struct { + Id int64 + ChatId ChatId + MessageId MessageId + + // Not present for notification-style chats + Sender *MemberId + + Data []byte + + ReferenceType *ReferenceType + Reference *string + + IsSilent bool + + // Note: No timestamp field, since it's encoded in MessageId +} + +// GetChatIdFromProto gets a chat ID from the protobuf variant +func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { + switch proto { + case chatpb.Pointer_SENT: + return PointerTypeSent + case chatpb.Pointer_DELIVERED: + return PointerTypeDelivered + case chatpb.Pointer_READ: + return PointerTypeRead + default: + return PointerTypeUnknown + } +} + +// String returns the string representation of the pointer type +func (p PointerType) String() string { + switch p { + case PointerTypeSent: + return "sent" + case PointerTypeDelivered: + return "delivered" + case PointerTypeRead: + return "read" + default: + return "unknown" + } +} + +// Validate validates a chat Record +func (r *ChatRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + switch r.ChatType { + case ChatTypeNotification: + if r.ChatTitle == nil || len(*r.ChatTitle) == 0 { + return errors.New("chat title is required for notification chats") + } + case ChatTypeTwoWay: + if r.ChatTitle != nil { + return errors.New("chat title cannot be set for two way chats") + } + default: + return errors.Errorf("invalid chat type: %d", r.ChatType) + } + + if r.CreatedAt.IsZero() { + return errors.New("creation timestamp is required") + } + + return nil +} + +// Clone clones a chat record +func (r *ChatRecord) Clone() ChatRecord { + return ChatRecord{ + Id: r.Id, + ChatId: r.ChatId, + + ChatType: r.ChatType, + + ChatTitle: pointer.StringCopy(r.ChatTitle), + + IsVerified: r.IsVerified, + + CreatedAt: r.CreatedAt, + } +} + +// CopyTo copies a chat record to the provided destination +func (r *ChatRecord) CopyTo(dst *ChatRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + + dst.ChatType = r.ChatType + + dst.ChatTitle = pointer.StringCopy(r.ChatTitle) + + dst.IsVerified = r.IsVerified + + dst.CreatedAt = r.CreatedAt +} + +// Validate validates a member Record +func (r *MemberRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + if err := r.MemberId.Validate(); err != nil { + return errors.Wrap(err, "invalid member id") + } + + if len(r.PlatformId) == 0 { + return errors.New("platform id is required") + } + + switch r.Platform { + case PlatformCode: + decoded, err := base58.Decode(r.PlatformId) + if err != nil { + return errors.Wrap(err, "invalid base58 plaftorm id") + } + + if len(decoded) != 32 { + return errors.Wrap(err, "platform id is not a 32 byte buffer") + } + case PlatformTwitter: + if len(r.PlatformId) > 15 { + return errors.New("platform id must have at most 15 characters") + } + default: + return errors.Errorf("invalid plaftorm: %d", r.Platform) + } + + if r.DeliveryPointer != nil { + if err := r.DeliveryPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid delivery pointer") + } + } + + if r.ReadPointer != nil { + if err := r.ReadPointer.Validate(); err != nil { + return errors.Wrap(err, "invalid read pointer") + } + } + + if r.JoinedAt.IsZero() { + return errors.New("joined timestamp is required") + } + + return nil +} + +// Clone clones a member record +func (r *MemberRecord) Clone() MemberRecord { + return MemberRecord{ + Id: r.Id, + ChatId: r.ChatId, + MemberId: r.MemberId, + + Platform: r.Platform, + PlatformId: r.PlatformId, + + DeliveryPointer: r.DeliveryPointer, // todo: pointer copy safety + ReadPointer: r.ReadPointer, // todo: pointer copy safety + + IsMuted: r.IsMuted, + IsUnsubscribed: r.IsUnsubscribed, + + JoinedAt: r.JoinedAt, + } +} + +// CopyTo copies a member record to the provided destination +func (r *MemberRecord) CopyTo(dst *MemberRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.MemberId = r.MemberId + + dst.Platform = r.Platform + dst.PlatformId = r.PlatformId + + dst.DeliveryPointer = r.DeliveryPointer // todo: pointer copy safety + dst.ReadPointer = r.ReadPointer // todo: pointer copy safety + + dst.IsMuted = r.IsMuted + dst.IsUnsubscribed = r.IsUnsubscribed + + dst.JoinedAt = r.JoinedAt +} + +// Validate validates a message Record +func (r *MessageRecord) Validate() error { + if err := r.ChatId.Validate(); err != nil { + return errors.Wrap(err, "invalid chat id") + } + + if err := r.MessageId.Validate(); err != nil { + return errors.Wrap(err, "invalid message id") + } + + if r.Sender != nil { + if err := r.Sender.Validate(); err != nil { + return errors.Wrap(err, "invalid sender id") + } + } + + if len(r.Data) == 0 { + return errors.New("message data is required") + } + + if r.Reference == nil && r.ReferenceType != nil { + return errors.New("reference is required when reference type is provided") + } + + if r.Reference != nil && r.ReferenceType == nil { + return errors.New("reference cannot be set when reference type is missing") + } + + if r.ReferenceType != nil { + switch *r.ReferenceType { + case ReferenceTypeIntent: + decoded, err := base58.Decode(*r.Reference) + if err != nil { + return errors.Wrap(err, "invalid base58 intent id reference") + } + + if len(decoded) != 32 { + return errors.Wrap(err, "reference is not a 32 byte buffer") + } + case ReferenceTypeSignature: + decoded, err := base58.Decode(*r.Reference) + if err != nil { + return errors.Wrap(err, "invalid base58 signature reference") + } + + if len(decoded) != 64 { + return errors.Wrap(err, "reference is not a 64 byte buffer") + } + default: + return errors.Errorf("invalid reference type: %d", *r.ReferenceType) + } + } + + return nil +} + +// Clone clones a message record +func (r *MessageRecord) Clone() MessageRecord { + return MessageRecord{ + Id: r.Id, + ChatId: r.ChatId, + MessageId: r.MessageId, + + Sender: r.Sender, // todo: pointer copy safety + + Data: r.Data, // todo: pointer copy safety + + ReferenceType: r.ReferenceType, // todo: pointer copy safety + Reference: r.Reference, // todo: pointer copy safety + + IsSilent: r.IsSilent, + + // todo: finish implementing me + } +} + +// CopyTo copies a message record to the provided destination +func (r *MessageRecord) CopyTo(dst *MessageRecord) { + dst.Id = r.Id + dst.ChatId = r.ChatId + dst.MessageId = r.MessageId + + dst.Sender = r.Sender // todo: pointer copy safety + + dst.Data = r.Data // todo: pointer copy safety + + dst.ReferenceType = r.ReferenceType // todo: pointer copy safety + dst.Reference = r.Reference // todo: pointer copy safety + + dst.IsSilent = r.IsSilent + + // todo: finish implementing me +} + +// GetTimestamp gets the timestamp for a message record +func (r *MessageRecord) GetTimestamp() (time.Time, error) { + return r.MessageId.GetTimestamp() +} diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go new file mode 100644 index 00000000..a4e9f858 --- /dev/null +++ b/pkg/code/data/chat/v2/store.go @@ -0,0 +1,9 @@ +package chat_v2 + +var ( +// todo: Define well-known errors here +) + +// todo: Define interface methods +type Store interface { +} From 3035a8f7c724510f2e35a8d7031a89375d515bb7 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 13:36:25 -0400 Subject: [PATCH 17/71] Implement RPCs that operate on chat member state --- pkg/code/data/chat/v2/store.go | 26 +++- pkg/code/data/internal.go | 34 +++++ pkg/code/server/grpc/chat/v2/server.go | 202 ++++++++++++++++++++++++- 3 files changed, 258 insertions(+), 4 deletions(-) diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index a4e9f858..b1d8d089 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -1,9 +1,33 @@ package chat_v2 +import ( + "context" + "errors" +) + var ( -// todo: Define well-known errors here + ErrChatNotFound = errors.New("chat not found") + ErrMemberNotFound = errors.New("chat member not found") + ErrMessageNotFound = errors.New("chat message not found") ) // todo: Define interface methods type Store interface { + // GetChatById gets a chat by its chat ID + GetChatById(ctx context.Context, chatId ChatId) (*ChatRecord, error) + + // GetMemberById gets a chat member by the chat and member IDs + GetMemberById(ctx context.Context, chatId ChatId, memberId MemberId) (*MemberRecord, error) + + // GetMessageById gets a chat message by the chat and message IDs + GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) + + // AdvancePointer advances a chat pointer for a chat member + AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) error + + // SetMuteState updates the mute state for a chat member + SetMuteState(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error + + // SetSubscriptionState updates the subscription state for a chat member + SetSubscriptionState(ctx context.Context, chatId ChatId, memberId MemberId, isSubscribed bool) error } diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 95ce26f2..20a72601 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -26,6 +26,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/badgecount" "github.com/code-payments/code-server/pkg/code/data/balance" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/contact" "github.com/code-payments/code-server/pkg/code/data/currency" @@ -392,6 +393,15 @@ type DatabaseData interface { SetChatMuteStateV1(ctx context.Context, chatId chat_v1.ChatId, isMuted bool) error SetChatSubscriptionStateV1(ctx context.Context, chatId chat_v1.ChatId, isSubscribed bool) error + // Chat V2 + // -------------------------------------------------------------------------------- + GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) + GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) + GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error + SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error + SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error + // Badge Count // -------------------------------------------------------------------------------- AddToBadgeCount(ctx context.Context, owner string, amount uint32) error @@ -471,6 +481,7 @@ type DatabaseProvider struct { event event.Store webhook webhook.Store chatv1 chat_v1.Store + chatv2 chat_v2.Store badgecount badgecount.Store login login.Store balance balance.Store @@ -533,6 +544,7 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { event: event_postgres_client.New(db), webhook: webhook_postgres_client.New(db), chatv1: chat_v1_postgres_client.New(db), + chatv2: nil, // todo: Initialize me badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), balance: balance_postgres_client.New(db), @@ -576,6 +588,7 @@ func NewTestDatabaseProvider() DatabaseData { event: event_memory_client.New(), webhook: webhook_memory_client.New(), chatv1: chat_v1_memory_client.New(), + chatv2: nil, // todo: initialize me badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), balance: balance_memory_client.New(), @@ -1443,6 +1456,27 @@ func (dp *DatabaseProvider) SetChatSubscriptionStateV1(ctx context.Context, chat return dp.chatv1.SetSubscriptionState(ctx, chatId, isSubscribed) } +// Chat V2 +// -------------------------------------------------------------------------------- +func (dp *DatabaseProvider) GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) { + return dp.chatv2.GetChatById(ctx, chatId) +} +func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) { + return dp.chatv2.GetMemberById(ctx, chatId, memberId) +} +func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { + return dp.chatv2.GetMessageById(ctx, chatId, messageId) +} +func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error { + return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) +} +func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { + return dp.chatv2.SetMuteState(ctx, chatId, memberId, isMuted) +} +func (dp *DatabaseProvider) SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error { + return dp.chatv2.SetSubscriptionState(ctx, chatId, memberId, isSubscribed) +} + // Badge Count // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) AddToBadgeCount(ctx context.Context, owner string, amount uint32) error { diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 37319f40..164bde2b 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -4,6 +4,7 @@ import ( "context" "time" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -13,7 +14,10 @@ import ( auth_util "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/twitter" "github.com/code-payments/code-server/pkg/grpc/client" + timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" ) // todo: Ensure all relevant logging fields are set @@ -140,13 +144,83 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.Pointer.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + pointerType := chat.GetPointerTypeFromProto(req.Pointer.Kind) + log = log.WithField("pointer_type", pointerType.String()) + switch pointerType { + case chat.PointerTypeDelivered, chat.PointerTypeRead: + default: + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + pointerValue, err := chat.GetMessageIdFromProto(req.Pointer.Value) + if err != nil { + log.WithError(err).Warn("invalid pointer value") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("pointer_value", pointerValue.String()) + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - return nil, status.Error(codes.Unimplemented, "") + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + isChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !isChatMember { + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + _, err = s.data.GetChatMessageByIdV2(ctx, chatId, pointerValue) + switch err { + case nil: + case chat.ErrMessageNotFound: + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat message record") + return nil, status.Error(codes.Internal, "") + } + + // Note: Guarantees that pointer will never be advanced to some point in the past + err = s.data.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointerValue) + if err != nil { + log.WithError(err).Warn("failure advancing chat pointer") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_OK, + }, nil } func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { @@ -160,13 +234,56 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - return nil, status.Error(codes.Unimplemented, "") + // todo: Use chat record to determine if muting is allowed + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + isChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !isChatMember { + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) + if err != nil { + log.WithError(err).Warn("failure setting mute state") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_OK, + }, nil } func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { @@ -180,11 +297,90 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - return nil, status.Error(codes.Unimplemented, "") + // todo: Use chat record to determine if muting is allowed + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + err = s.data.SetChatSubscriptionStateV2(ctx, chatId, memberId, req.IsSubscribed) + if err != nil { + log.WithError(err).Warn("failure setting mute state") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_OK, + }, nil +} + +func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { + memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) + switch err { + case nil: + case chat.ErrMemberNotFound: + return false, nil + default: + return false, errors.Wrap(err, "error getting member record") + } + + switch memberRecord.Platform { + case chat.PlatformCode: + return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil + case chat.PlatformTwitter: + // todo: This logic should live elsewhere in somewhere more common + + ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) + if err != nil { + return false, errors.Wrap(err, "error deriving twitter tip address") + } + + twitterRecord, err := s.data.GetTwitterUserByUsername(ctx, memberRecord.PlatformId) + switch err { + case nil: + case twitter.ErrUserNotFound: + return false, nil + default: + return false, errors.Wrap(err, "error getting twitter user") + } + + return twitterRecord.TipAddress == ownerTipAccount.PublicKey().ToBase58(), nil + default: + return false, nil + } } From 577f4cc598f6045f0ce96d27638898c571070ea5 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 13:49:33 -0400 Subject: [PATCH 18/71] Implement StreamChatEvents with integration of pointer events --- pkg/code/server/grpc/chat/v2/server.go | 163 +++++++++++++++++++++++-- pkg/code/server/grpc/chat/v2/stream.go | 144 ++++++++++++++++++++++ 2 files changed, 298 insertions(+), 9 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 164bde2b..a8f0592a 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -2,14 +2,19 @@ package chat_v2 import ( "context" + "fmt" + "sync" "time" "github.com/pkg/errors" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" auth_util "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" @@ -18,6 +23,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/twitter" "github.com/code-payments/code-server/pkg/grpc/client" timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" + sync_util "github.com/code-payments/code-server/pkg/sync" ) // todo: Ensure all relevant logging fields are set @@ -27,15 +33,33 @@ type server struct { data code_data.Provider auth *auth_util.RPCSignatureVerifier + streamsMu sync.RWMutex + streams map[string]*chatEventStream + + chatLocks *sync_util.StripedLock + chatEventChans *sync_util.StripedChannel + chatpb.UnimplementedChatServer } func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - return &server{ - log: logrus.StandardLogger().WithField("type", "chat/v2/server"), + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/v2/server"), + data: data, auth: auth, + + streams: make(map[string]*chatEventStream), + + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters + } + + for i, channel := range s.chatEventChans.GetChannels() { + go s.asyncChatEventStreamNotifier(i, channel) } + + return s } func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { @@ -93,10 +117,6 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e return status.Error(codes.InvalidArgument, "open_stream is nil") } - if req.GetOpenStream().Signature == nil { - return status.Error(codes.InvalidArgument, "signature is nil") - } - owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) if err != nil { log.WithError(err).Warn("invalid owner account") @@ -104,13 +124,129 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } log = log.WithField("owner", owner.PublicKey().ToBase58()) + chatId, err := chat.GetChatIdFromProto(req.GetOpenStream().ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.GetOpenStream().MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + signature := req.GetOpenStream().Signature req.GetOpenStream().Signature = nil if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { return err } - return status.Error(codes.Unimplemented, "") + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return status.Error(codes.Unimplemented, "todo: missing result code") + default: + log.WithError(err).Warn("failure getting chat record") + return status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return status.Error(codes.Internal, "") + } else if !ownsChatMember { + return status.Error(codes.Unimplemented, "todo: missing result code") + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) + + s.streamsMu.Lock() + + stream, exists := s.streams[streamKey] + if exists { + s.streamsMu.Unlock() + // There's an existing stream on this server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + stream = newChatEventStream(streamBufferSize) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = stream + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + + for { + select { + case event, ok := <-stream.streamCh: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: &chatpb.ChatStreamEventBatch{ + Events: []*chatpb.ChatStreamEvent{event}, + }, + }, + }) + if err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } } func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { @@ -191,11 +327,11 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.Internal, "") } - isChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") - } else if !isChatMember { + } else if !ownsChatMember { return nil, status.Error(codes.Unimplemented, "todo: missing result code") } @@ -218,6 +354,15 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.Internal, "") } + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: req.Pointer, + }, + } + if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_OK, }, nil diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go index bf986572..ec797a77 100644 --- a/pkg/code/server/grpc/chat/v2/stream.go +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -2,14 +2,75 @@ package chat_v2 import ( "context" + "strings" + "sync" "time" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" ) +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + func boundedStreamChatEventsRecv( ctx context.Context, streamer chatpb.Chat_StreamChatEventsServer, @@ -30,3 +91,86 @@ func boundedStreamChatEventsRecv( return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") } } + +type chatEventNotification struct { + chatId chat.ChatId + memberId chat.MemberId + event *chatpb.ChatStreamEvent + ts time.Time +} + +func (s *server) asyncNotifyAll(chatId chat.ChatId, memberId chat.MemberId, event *chatpb.ChatStreamEvent) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, memberId, m, time.Now()}) + if !ok { + return errors.New("chat event channel is full") + } + return nil +} + +func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log := log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, typedValue.chatId.String()) { + continue + } + + if strings.HasSuffix(key, typedValue.memberId.String()) { + continue + } + + if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} From 8585a38eab6934ddb299af04f45a927fc48b25d3 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 14:07:21 -0400 Subject: [PATCH 19/71] Implement GetMessages RPC --- pkg/code/data/chat/v2/store.go | 7 ++ pkg/code/data/internal.go | 8 ++ pkg/code/server/grpc/chat/v2/server.go | 129 ++++++++++++++++++++++++- 3 files changed, 143 insertions(+), 1 deletion(-) diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index b1d8d089..2abb34e1 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -3,6 +3,8 @@ package chat_v2 import ( "context" "errors" + + "github.com/code-payments/code-server/pkg/database/query" ) var ( @@ -22,6 +24,11 @@ type Store interface { // GetMessageById gets a chat message by the chat and message IDs GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) + // GetAllMessagesByChat gets all messages for a given chat + // + // Note: Cursor is a message ID + GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + // AdvancePointer advances a chat pointer for a chat member AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 20a72601..d7c35b71 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -398,6 +398,7 @@ type DatabaseData interface { GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error @@ -1467,6 +1468,13 @@ func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { return dp.chatv2.GetMessageById(ctx, chatId, messageId) } +func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } + return dp.chatv2.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) +} func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error { return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index a8f0592a..5c8246e5 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -8,8 +8,10 @@ import ( "github.com/pkg/errors" "github.com/sirupsen/logrus" + "golang.org/x/text/language" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -21,11 +23,17 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/twitter" + "github.com/code-payments/code-server/pkg/code/localization" + "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" sync_util "github.com/code-payments/code-server/pkg/sync" ) +const ( + maxGetMessagesPageSize = 100 +) + // todo: Ensure all relevant logging fields are set type server struct { log *logrus.Entry @@ -93,13 +101,132 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - return nil, status.Error(codes.Unimplemented, "") + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetMessagesPageSize + } + if limit > maxGetMessagesPageSize { + limit = maxGetMessagesPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetMessagesRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } + + messageRecords, err := s.data.GetAllChatMessagesV2( + ctx, + chatId, + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + if err == chat.ErrMessageNotFound { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_NOT_FOUND, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat message records") + return nil, status.Error(codes.Internal, "") + } + + var userLocale *language.Tag // Loaded lazily when required + var protoChatMessages []*chatpb.ChatMessage + for _, messageRecord := range messageRecords { + log := log.WithField("message_id", messageRecord.MessageId.String()) + + var protoChatMessage chatpb.ChatMessage + err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) + if err != nil { + log.WithError(err).Warn("failure unmarshalling proto chat message") + return nil, status.Error(codes.Internal, "") + } + + ts, err := messageRecord.GetTimestamp() + if err != nil { + log.WithError(err).Warn("failure getting message timestamp") + return nil, status.Error(codes.Internal, "") + } + + for _, content := range protoChatMessage.Content { + switch typed := content.Type.(type) { + case *chatpb.Content_Localized: + if userLocale == nil { + loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + log.WithError(err).Warn("failure getting user locale") + return nil, status.Error(codes.Internal, "") + } + userLocale = &loadedUserLocale + } + + typed.Localized.KeyOrText = localization.LocalizeWithFallback( + *userLocale, + localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), + typed.Localized.KeyOrText, + ) + } + } + + protoChatMessage.MessageId = &chatpb.ChatMessageId{Value: messageRecord.MessageId[:]} + if messageRecord.Sender != nil { + protoChatMessage.SenderId = &chatpb.ChatMemberId{Value: messageRecord.Sender[:]} + } + protoChatMessage.Ts = timestamppb.New(ts) + protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} + } + + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_OK, + Messages: protoChatMessages, + }, nil } func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { From 24e82186b6e83346d7facfdb366cc846266d4acb Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Fri, 7 Jun 2024 14:08:25 -0400 Subject: [PATCH 20/71] Add reminder to add a flush on StreamChatEvents --- pkg/code/server/grpc/chat/v2/server.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 5c8246e5..fdf3b4fd 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -229,6 +229,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest }, nil } +// todo: Requires a "flush" of most recent messages func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { ctx := streamer.Context() From b6d0756e0a45ca511d802f3f74abf15398ac25fe Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 09:46:44 -0400 Subject: [PATCH 21/71] Add a flush on StreamChatEvents stream open --- pkg/code/server/grpc/chat/v2/server.go | 146 ++++++++++++++++--------- 1 file changed, 92 insertions(+), 54 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index fdf3b4fd..eed7ec8c 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -161,75 +161,25 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest cursor = req.Cursor.Value } - messageRecords, err := s.data.GetAllChatMessagesV2( + protoChatMessages, err := s.getProtoChatMessages( ctx, chatId, + owner, query.WithCursor(cursor), query.WithDirection(direction), query.WithLimit(limit), ) - if err == chat.ErrMessageNotFound { - return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_NOT_FOUND, - }, nil - } else if err != nil { - log.WithError(err).Warn("failure getting chat message records") + if err != nil { + log.WithError(err).Warn("failure getting chat messages") return nil, status.Error(codes.Internal, "") } - var userLocale *language.Tag // Loaded lazily when required - var protoChatMessages []*chatpb.ChatMessage - for _, messageRecord := range messageRecords { - log := log.WithField("message_id", messageRecord.MessageId.String()) - - var protoChatMessage chatpb.ChatMessage - err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) - if err != nil { - log.WithError(err).Warn("failure unmarshalling proto chat message") - return nil, status.Error(codes.Internal, "") - } - - ts, err := messageRecord.GetTimestamp() - if err != nil { - log.WithError(err).Warn("failure getting message timestamp") - return nil, status.Error(codes.Internal, "") - } - - for _, content := range protoChatMessage.Content { - switch typed := content.Type.(type) { - case *chatpb.Content_Localized: - if userLocale == nil { - loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) - if err != nil { - log.WithError(err).Warn("failure getting user locale") - return nil, status.Error(codes.Internal, "") - } - userLocale = &loadedUserLocale - } - - typed.Localized.KeyOrText = localization.LocalizeWithFallback( - *userLocale, - localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), - typed.Localized.KeyOrText, - ) - } - } - - protoChatMessage.MessageId = &chatpb.ChatMessageId{Value: messageRecord.MessageId[:]} - if messageRecord.Sender != nil { - protoChatMessage.SenderId = &chatpb.ChatMemberId{Value: messageRecord.Sender[:]} - } - protoChatMessage.Ts = timestamppb.New(ts) - protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} - } - return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_OK, Messages: protoChatMessages, }, nil } -// todo: Requires a "flush" of most recent messages func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { ctx := streamer.Context() @@ -331,6 +281,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + go s.flush(ctx, chatId, owner, stream) + for { select { case event, ok := <-stream.streamCh: @@ -622,6 +574,92 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } +func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flush", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } + + for _, protoChatMessage := range protoChatMessages { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: protoChatMessage, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} + +func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { + messageRecords, err := s.data.GetAllChatMessagesV2( + ctx, + chatId, + queryOptions..., + ) + if err == chat.ErrMessageNotFound { + return nil, err + } + + var userLocale *language.Tag // Loaded lazily when required + var res []*chatpb.ChatMessage + for _, messageRecord := range messageRecords { + var protoChatMessage chatpb.ChatMessage + err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) + if err != nil { + return nil, errors.Wrap(err, "error unmarshalling proto chat message") + } + + ts, err := messageRecord.GetTimestamp() + if err != nil { + return nil, errors.Wrap(err, "error getting message timestamp") + } + + for _, content := range protoChatMessage.Content { + switch typed := content.Type.(type) { + case *chatpb.Content_Localized: + if userLocale == nil { + loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + return nil, errors.Wrap(err, "error getting user locale") + } + userLocale = &loadedUserLocale + } + + typed.Localized.KeyOrText = localization.LocalizeWithFallback( + *userLocale, + localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), + typed.Localized.KeyOrText, + ) + } + } + + protoChatMessage.MessageId = &chatpb.ChatMessageId{Value: messageRecord.MessageId[:]} + if messageRecord.Sender != nil { + protoChatMessage.SenderId = &chatpb.ChatMemberId{Value: messageRecord.Sender[:]} + } + protoChatMessage.Ts = timestamppb.New(ts) + protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} + + res = append(res, &protoChatMessage) + } + + return res, nil +} + func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { From 451b155531cbf265409defdd44bdf7606003616d Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 09:48:26 -0400 Subject: [PATCH 22/71] Fix todo comment --- pkg/code/server/grpc/chat/v2/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index eed7ec8c..c8f7459a 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -542,7 +542,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr return nil, err } - // todo: Use chat record to determine if muting is allowed + // todo: Use chat record to determine if unsubscribing is allowed _, err = s.data.GetChatByIdV2(ctx, chatId) switch err { case nil: From c6d5b884c1596311e87a911d988ae2f6ee972ab9 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 09:53:37 -0400 Subject: [PATCH 23/71] Add missing query parameters to flush --- pkg/code/server/grpc/chat/v2/server.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index c8f7459a..30e37b23 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -32,6 +32,7 @@ import ( const ( maxGetMessagesPageSize = 100 + flushMessageCount = 100 ) // todo: Ensure all relevant logging fields are set @@ -581,10 +582,15 @@ func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Ac "owner_account": owner.PublicKey().ToBase58(), }) + cursorValue := chat.GenerateMessageIdAtTime(time.Now().Add(2 * time.Second)) + protoChatMessages, err := s.getProtoChatMessages( ctx, chatId, owner, + query.WithCursor(cursorValue[:]), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), ) if err != nil { log.WithError(err).Warn("failure getting chat messages") From e29f89908d1eff48e73462fd5a4230fee5bc88b2 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 10:35:29 -0400 Subject: [PATCH 24/71] Implement the SendMessage RPC without consideration for other chat features --- pkg/code/data/chat/v2/id.go | 21 ++- pkg/code/data/chat/v2/store.go | 3 + pkg/code/data/internal.go | 4 + pkg/code/server/grpc/chat/v2/server.go | 214 ++++++++++++++++++++----- 4 files changed, 200 insertions(+), 42 deletions(-) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index 267ae7f2..b9ef4667 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -29,6 +29,11 @@ func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { return typed, nil } +// ToProto converts a chat ID to its protobuf variant +func (c ChatId) ToProto() *chatpb.ChatId { + return &chatpb.ChatId{Value: c[:]} +} + // Validate validates a chat ID func (c ChatId) Validate() error { return nil @@ -42,6 +47,11 @@ func (c ChatId) String() string { // Random UUIDv4 ID for chat members type MemberId uuid.UUID +// GenerateMemberId generates a new random chat member ID +func GenerateMemberId() MemberId { + return MemberId(uuid.New()) +} + // GetMemberIdFromProto gets a member ID from the protobuf variant func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { if err := proto.Validate(); err != nil { @@ -58,9 +68,9 @@ func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { return typed, nil } -// GenerateMemberId generates a new random chat member ID -func GenerateMemberId() MemberId { - return MemberId(uuid.New()) +// ToProto converts a message ID to its protobuf variant +func (m MemberId) ToProto() *chatpb.ChatMemberId { + return &chatpb.ChatMemberId{Value: m[:]} } // Validate validates a chat member ID @@ -129,6 +139,11 @@ func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { return typed, nil } +// ToProto converts a message ID to its protobuf variant +func (m MessageId) ToProto() *chatpb.ChatMessageId { + return &chatpb.ChatMessageId{Value: m[:]} +} + // GetTimestamp gets the encoded timestamp in the message ID func (m MessageId) GetTimestamp() (time.Time, error) { if err := m.Validate(); err != nil { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 2abb34e1..54e00932 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -29,6 +29,9 @@ type Store interface { // Note: Cursor is a message ID GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + // PutMessage creates a new chat message + PutMessage(ctx context.Context, record *MessageRecord) error + // AdvancePointer advances a chat pointer for a chat member AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index d7c35b71..f2d1af0c 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -399,6 +399,7 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) + PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error @@ -1475,6 +1476,9 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } +func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { + return dp.chatv2.PutMessage(ctx, record) +} func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error { return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 30e37b23..b3f815fe 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -330,6 +330,41 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } } +func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flush", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) + + cursorValue := chat.GenerateMessageIdAtTime(time.Now().Add(2 * time.Second)) + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(cursorValue[:]), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } + + for _, protoChatMessage := range protoChatMessages { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: protoChatMessage, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} + func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { log := s.log.WithField("method", "SendMessage") log = client.InjectLoggingMetadata(ctx, log) @@ -339,7 +374,29 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest log.WithError(err).Warn("invalid owner account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("owner", owner.PublicKey().ToBase58()) + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + switch req.Content[0].Type.(type) { + case *chatpb.Content_Text, *chatpb.Content_ThankYou: + default: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_INVALID_CONTENT_TYPE, + }, nil + } signature := req.Signature req.Signature = nil @@ -347,7 +404,121 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, err } - return nil, status.Error(codes.Unimplemented, "") + chatRecord, err := s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + default: + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_INVALID_CHAT_TYPE, + }, nil + } + + ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return nil, status.Error(codes.Unimplemented, "todo: missing result code") + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + chatMessage := &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: req.MemberId, + Content: req.Content, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } + + err = s.persistChatMessage(ctx, chatId, chatMessage) + if err != nil { + log.WithError(err).Warn("failure persisting chat message") + return nil, status.Error(codes.Internal, "") + } + + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, + } + if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + // todo: send the push + + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_OK, + Message: chatMessage, + }, nil +} + +// todo: This belongs in the common chat utility, which currently only operates on v1 chats +func (s *server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { + if err := protoChatMessage.Validate(); err != nil { + return errors.Wrap(err, "proto chat message failed validation") + } + + messageId, err := chat.GetMessageIdFromProto(protoChatMessage.MessageId) + if err != nil { + return errors.Wrap(err, "invalid message id") + } + + var senderId *chat.MemberId + if protoChatMessage.SenderId != nil { + convertedSenderId, err := chat.GetMemberIdFromProto(protoChatMessage.SenderId) + if err != nil { + return errors.Wrap(err, "invalid member id") + } + senderId = &convertedSenderId + } + + // Clear out extracted metadata as a space optimization + cloned := proto.Clone(protoChatMessage).(*chatpb.ChatMessage) + cloned.MessageId = nil + cloned.SenderId = nil + cloned.Ts = nil + cloned.Cursor = nil + + marshalled, err := proto.Marshal(cloned) + if err != nil { + return errors.Wrap(err, "error marshalling proto chat message") + } + + // todo: Doesn't incoroporate reference. We might want to promote the proto a level above the content. + messageRecord := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + + Sender: senderId, + + Data: marshalled, + + IsSilent: false, + } + + err = s.data.PutChatMessageV2(ctx, messageRecord) + if err != nil { + return errors.Wrap(err, "error persiting chat message") + } + return nil } func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { @@ -575,41 +746,6 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } -func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { - log := s.log.WithFields(logrus.Fields{ - "method": "flush", - "chat_id": chatId.String(), - "owner_account": owner.PublicKey().ToBase58(), - }) - - cursorValue := chat.GenerateMessageIdAtTime(time.Now().Add(2 * time.Second)) - - protoChatMessages, err := s.getProtoChatMessages( - ctx, - chatId, - owner, - query.WithCursor(cursorValue[:]), - query.WithDirection(query.Descending), - query.WithLimit(flushMessageCount), - ) - if err != nil { - log.WithError(err).Warn("failure getting chat messages") - return - } - - for _, protoChatMessage := range protoChatMessages { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: protoChatMessage, - }, - } - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - return - } - } -} - func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { messageRecords, err := s.data.GetAllChatMessagesV2( ctx, @@ -653,9 +789,9 @@ func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o } } - protoChatMessage.MessageId = &chatpb.ChatMessageId{Value: messageRecord.MessageId[:]} + protoChatMessage.MessageId = messageRecord.MessageId.ToProto() if messageRecord.Sender != nil { - protoChatMessage.SenderId = &chatpb.ChatMemberId{Value: messageRecord.Sender[:]} + protoChatMessage.SenderId = messageRecord.Sender.ToProto() } protoChatMessage.Ts = timestamppb.New(ts) protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} From db4877c6f8667f1d1bcdbc78b3750cfa0a8d6617 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 15:07:26 -0400 Subject: [PATCH 25/71] Skeleton for minimal memory PoC chat v2 data store --- pkg/code/data/chat/v2/memory/store.go | 78 ++++++++++++++++++++++ pkg/code/data/chat/v2/memory/store_test.go | 15 +++++ pkg/code/data/chat/v2/store.go | 6 ++ pkg/code/data/chat/v2/tests/tests.go | 14 ++++ pkg/code/data/internal.go | 13 +++- 5 files changed, 124 insertions(+), 2 deletions(-) create mode 100644 pkg/code/data/chat/v2/memory/store.go create mode 100644 pkg/code/data/chat/v2/memory/store_test.go create mode 100644 pkg/code/data/chat/v2/tests/tests.go diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go new file mode 100644 index 00000000..cdd8e4b1 --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store.go @@ -0,0 +1,78 @@ +package memory + +import ( + "context" + "errors" + "sync" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" +) + +// todo: implement me +type store struct { + mu sync.Mutex + last uint64 +} + +// New returns a new in memory chat.Store +func New() chat.Store { + return &store{} +} + +// GetChatById implements chat.Store.GetChatById +func (s *store) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.ChatRecord, error) { + return nil, errors.New("not implemented") +} + +// GetMemberById implements chat.Store.GetMemberById +func (s *store) GetMemberById(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId) (*chat.MemberRecord, error) { + return nil, errors.New("not implemented") +} + +// GetMessageById implements chat.Store.GetMessageById +func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { + return nil, errors.New("not implemented") +} + +// GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat +func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { + return nil, errors.New("not implemented") +} + +// PutChat creates a new chat +func (s *store) PutChat(ctx context.Context, record *chat.ChatRecord) error { + return errors.New("not implemented") +} + +// PutMember creates a new chat member +func (s *store) PutMember(ctx context.Context, record *chat.MemberRecord) error { + return errors.New("not implemented") +} + +// PutMessage implements chat.Store.PutMessage +func (s *store) PutMessage(ctx context.Context, record *chat.MessageRecord) error { + return errors.New("not implemented") +} + +// AdvancePointer implements chat.Store.AdvancePointer +func (s *store) AdvancePointer(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) error { + return errors.New("not implemented") +} + +// SetMuteState implements chat.Store.SetMuteState +func (s *store) SetMuteState(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { + return errors.New("not implemented") +} + +// SetSubscriptionState implements chat.Store.SetSubscriptionState +func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isSubscribed bool) error { + return errors.New("not implemented") +} + +func (s *store) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.last = 0 +} diff --git a/pkg/code/data/chat/v2/memory/store_test.go b/pkg/code/data/chat/v2/memory/store_test.go new file mode 100644 index 00000000..cd61dfa4 --- /dev/null +++ b/pkg/code/data/chat/v2/memory/store_test.go @@ -0,0 +1,15 @@ +package memory + +import ( + "testing" + + "github.com/code-payments/code-server/pkg/code/data/chat/v2/tests" +) + +func TestChatMemoryStore(t *testing.T) { + testStore := New() + teardown := func() { + testStore.(*store).reset() + } + tests.RunTests(t, testStore, teardown) +} diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 54e00932..ab73943e 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -29,6 +29,12 @@ type Store interface { // Note: Cursor is a message ID GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + // PutChat creates a new chat + PutChat(ctx context.Context, record *ChatRecord) error + + // PutMember creates a new chat member + PutMember(ctx context.Context, record *MemberRecord) error + // PutMessage creates a new chat message PutMessage(ctx context.Context, record *MessageRecord) error diff --git a/pkg/code/data/chat/v2/tests/tests.go b/pkg/code/data/chat/v2/tests/tests.go new file mode 100644 index 00000000..94c85a89 --- /dev/null +++ b/pkg/code/data/chat/v2/tests/tests.go @@ -0,0 +1,14 @@ +package tests + +import ( + "testing" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" +) + +func RunTests(t *testing.T, s chat.Store, teardown func()) { + for _, tf := range []func(t *testing.T, s chat.Store){} { + tf(t, s) + teardown() + } +} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index f2d1af0c..68cce410 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -61,6 +61,7 @@ import ( badgecount_memory_client "github.com/code-payments/code-server/pkg/code/data/badgecount/memory" balance_memory_client "github.com/code-payments/code-server/pkg/code/data/balance/memory" chat_v1_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v1/memory" + chat_v2_memory_client "github.com/code-payments/code-server/pkg/code/data/chat/v2/memory" commitment_memory_client "github.com/code-payments/code-server/pkg/code/data/commitment/memory" contact_memory_client "github.com/code-payments/code-server/pkg/code/data/contact/memory" currency_memory_client "github.com/code-payments/code-server/pkg/code/data/currency/memory" @@ -399,6 +400,8 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) + PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error + PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error @@ -546,7 +549,7 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { event: event_postgres_client.New(db), webhook: webhook_postgres_client.New(db), chatv1: chat_v1_postgres_client.New(db), - chatv2: nil, // todo: Initialize me + chatv2: chat_v2_memory_client.New(), // todo: Postgres version for production after PoC badgecount: badgecount_postgres_client.New(db), login: login_postgres_client.New(db), balance: balance_postgres_client.New(db), @@ -590,7 +593,7 @@ func NewTestDatabaseProvider() DatabaseData { event: event_memory_client.New(), webhook: webhook_memory_client.New(), chatv1: chat_v1_memory_client.New(), - chatv2: nil, // todo: initialize me + chatv2: chat_v2_memory_client.New(), badgecount: badgecount_memory_client.New(), login: login_memory_client.New(), balance: balance_memory_client.New(), @@ -1476,6 +1479,12 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } +func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { + return dp.chatv2.PutChat(ctx, record) +} +func (dp *DatabaseProvider) PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error { + return dp.chatv2.PutMember(ctx, record) +} func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { return dp.chatv2.PutMessage(ctx, record) } From e30053337f8087370c607ad7dae4f5f4edb7e5ec Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 10 Jun 2024 15:26:00 -0400 Subject: [PATCH 26/71] Implement chat v2 memory data store (WIP) --- pkg/code/data/chat/v2/memory/store.go | 173 +++++++++++++++++++++++--- pkg/code/data/chat/v2/store.go | 7 +- 2 files changed, 157 insertions(+), 23 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index cdd8e4b1..e3802aa0 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -1,6 +1,7 @@ package memory import ( + "bytes" "context" "errors" "sync" @@ -9,10 +10,17 @@ import ( "github.com/code-payments/code-server/pkg/database/query" ) -// todo: implement me +// todo: finish implementing me type store struct { - mu sync.Mutex - last uint64 + mu sync.Mutex + + chatRecords []*chat.ChatRecord + memberRecords []*chat.MemberRecord + messageRecords []*chat.MessageRecord + + lastChatId uint64 + lastMemberId uint64 + lastMessageId uint64 } // New returns a new in memory chat.Store @@ -21,58 +29,183 @@ func New() chat.Store { } // GetChatById implements chat.Store.GetChatById -func (s *store) GetChatById(ctx context.Context, chatId chat.ChatId) (*chat.ChatRecord, error) { - return nil, errors.New("not implemented") +func (s *store) GetChatById(_ context.Context, chatId chat.ChatId) (*chat.ChatRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findChatById(chatId) + if item == nil { + return nil, chat.ErrChatNotFound + } + + cloned := item.Clone() + return &cloned, nil } // GetMemberById implements chat.Store.GetMemberById -func (s *store) GetMemberById(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId) (*chat.MemberRecord, error) { - return nil, errors.New("not implemented") +func (s *store) GetMemberById(_ context.Context, chatId chat.ChatId, memberId chat.MemberId) (*chat.MemberRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return nil, chat.ErrMemberNotFound + } + + cloned := item.Clone() + return &cloned, nil } // GetMessageById implements chat.Store.GetMessageById -func (s *store) GetMessageById(ctx context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { - return nil, errors.New("not implemented") +func (s *store) GetMessageById(_ context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMessageById(chatId, messageId) + if item == nil { + return nil, chat.ErrMessageNotFound + } + + cloned := item.Clone() + return &cloned, nil } // GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat -func (s *store) GetAllMessagesByChat(ctx context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { +func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + return nil, errors.New("not implemented") } // PutChat creates a new chat -func (s *store) PutChat(ctx context.Context, record *chat.ChatRecord) error { +func (s *store) PutChat(_ context.Context, record *chat.ChatRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + return errors.New("not implemented") } // PutMember creates a new chat member -func (s *store) PutMember(ctx context.Context, record *chat.MemberRecord) error { +func (s *store) PutMember(_ context.Context, record *chat.MemberRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + return errors.New("not implemented") } // PutMessage implements chat.Store.PutMessage -func (s *store) PutMessage(ctx context.Context, record *chat.MessageRecord) error { +func (s *store) PutMessage(_ context.Context, record *chat.MessageRecord) error { + s.mu.Lock() + defer s.mu.Unlock() + return errors.New("not implemented") } // AdvancePointer implements chat.Store.AdvancePointer -func (s *store) AdvancePointer(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) error { - return errors.New("not implemented") +func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) error { + switch pointerType { + case chat.PointerTypeDelivered, chat.PointerTypeRead: + default: + return chat.ErrInvalidPointerType + } + + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + + var currentPointer *chat.MessageId + switch pointerType { + case chat.PointerTypeDelivered: + currentPointer = item.DeliveryPointer + case chat.PointerTypeRead: + currentPointer = item.ReadPointer + } + + if currentPointer != nil && currentPointer.After(pointer) { + return nil + } + + switch pointerType { + case chat.PointerTypeDelivered: + item.DeliveryPointer = &pointer // todo: pointer copy safety + case chat.PointerTypeRead: + item.ReadPointer = &pointer // todo: pointer copy safety + } + + return nil } // SetMuteState implements chat.Store.SetMuteState -func (s *store) SetMuteState(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { - return errors.New("not implemented") +func (s *store) SetMuteState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + + item.IsMuted = isMuted + + return nil } // SetSubscriptionState implements chat.Store.SetSubscriptionState -func (s *store) SetSubscriptionState(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isSubscribed bool) error { - return errors.New("not implemented") +func (s *store) SetSubscriptionState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isSubscribed bool) error { + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + + item.IsUnsubscribed = !isSubscribed + + return nil +} + +func (s *store) findChatById(chatId chat.ChatId) *chat.ChatRecord { + for _, item := range s.chatRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + return item + } + } + return nil +} + +func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat.MemberRecord { + for _, item := range s.memberRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(memberId[:], item.MemberId[:]) { + return item + } + } + return nil +} + +func (s *store) findMessageById(chatId chat.ChatId, messageId chat.MessageId) *chat.MessageRecord { + for _, item := range s.messageRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(messageId[:], item.MessageId[:]) { + return item + } + } + return nil } func (s *store) reset() { s.mu.Lock() defer s.mu.Unlock() - s.last = 0 + s.chatRecords = nil + s.memberRecords = nil + s.messageRecords = nil + + s.lastChatId = 0 + s.lastMemberId = 0 + s.lastMessageId = 0 } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index ab73943e..e7234b3a 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -8,9 +8,10 @@ import ( ) var ( - ErrChatNotFound = errors.New("chat not found") - ErrMemberNotFound = errors.New("chat member not found") - ErrMessageNotFound = errors.New("chat message not found") + ErrChatNotFound = errors.New("chat not found") + ErrMemberNotFound = errors.New("chat member not found") + ErrMessageNotFound = errors.New("chat message not found") + ErrInvalidPointerType = errors.New("invalid pointer type") ) // todo: Define interface methods From f1f2568a9f3e7668441c018263b693e8585f4431 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 10:05:42 -0400 Subject: [PATCH 27/71] Implement remaining chat memory store defined methods --- pkg/code/data/chat/v2/id.go | 57 ++++++--- pkg/code/data/chat/v2/memory/store.go | 177 ++++++++++++++++++++++++-- pkg/code/data/chat/v2/model.go | 8 ++ pkg/code/data/chat/v2/store.go | 3 + 4 files changed, 222 insertions(+), 23 deletions(-) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index b9ef4667..4f7b7b0a 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -13,14 +13,14 @@ import ( type ChatId [32]byte -// GetChatIdFromProto gets a chat ID from the protobuf variant -func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { - if err := proto.Validate(); err != nil { - return ChatId{}, errors.Wrap(err, "proto validation failed") +// GetChatIdFromBytes gets a chat ID from a byte buffer +func GetChatIdFromBytes(buffer []byte) (ChatId, error) { + if len(buffer) != 32 { + return ChatId{}, errors.New("chat id must be 32 bytes in length") } var typed ChatId - copy(typed[:], proto.Value) + copy(typed[:], buffer[:]) if err := typed.Validate(); err != nil { return ChatId{}, errors.Wrap(err, "invalid chat id") @@ -29,6 +29,15 @@ func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { return typed, nil } +// GetChatIdFromProto gets a chat ID from the protobuf variant +func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { + if err := proto.Validate(); err != nil { + return ChatId{}, errors.Wrap(err, "proto validation failed") + } + + return GetChatIdFromBytes(proto.Value) +} + // ToProto converts a chat ID to its protobuf variant func (c ChatId) ToProto() *chatpb.ChatId { return &chatpb.ChatId{Value: c[:]} @@ -52,14 +61,14 @@ func GenerateMemberId() MemberId { return MemberId(uuid.New()) } -// GetMemberIdFromProto gets a member ID from the protobuf variant -func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { - if err := proto.Validate(); err != nil { - return MemberId{}, errors.Wrap(err, "proto validation failed") +// GetMemberIdFromBytes gets a member ID from a byte buffer +func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { + if len(buffer) != 16 { + return MemberId{}, errors.New("member id must be 16 bytes in length") } var typed MemberId - copy(typed[:], proto.Value) + copy(typed[:], buffer[:]) if err := typed.Validate(); err != nil { return MemberId{}, errors.Wrap(err, "invalid member id") @@ -68,6 +77,15 @@ func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { return typed, nil } +// GetMemberIdFromProto gets a member ID from the protobuf variant +func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { + if err := proto.Validate(); err != nil { + return MemberId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMemberIdFromBytes(proto.Value) +} + // ToProto converts a message ID to its protobuf variant func (m MemberId) ToProto() *chatpb.ChatMemberId { return &chatpb.ChatMemberId{Value: m[:]} @@ -123,14 +141,14 @@ func GenerateMessageIdAtTime(ts time.Time) MessageId { return MessageId(uuidBytes) } -// GetMessageIdFromProto gets a message ID from the protobuf variant -func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { - if err := proto.Validate(); err != nil { - return MessageId{}, errors.Wrap(err, "proto validation failed") +// GetMessageIdFromBytes gets a message ID from a byte buffer +func GetMessageIdFromBytes(buffer []byte) (MessageId, error) { + if len(buffer) != 16 { + return MessageId{}, errors.New("message id must be 16 bytes in length") } var typed MessageId - copy(typed[:], proto.Value) + copy(typed[:], buffer[:]) if err := typed.Validate(); err != nil { return MessageId{}, errors.Wrap(err, "invalid message id") @@ -139,6 +157,15 @@ func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { return typed, nil } +// GetMessageIdFromProto gets a message ID from the protobuf variant +func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { + if err := proto.Validate(); err != nil { + return MessageId{}, errors.Wrap(err, "proto validation failed") + } + + return GetMessageIdFromBytes(proto.Value) +} + // ToProto converts a message ID to its protobuf variant func (m MessageId) ToProto() *chatpb.ChatMessageId { return &chatpb.ChatMessageId{Value: m[:]} diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index e3802aa0..c5feb913 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -3,8 +3,9 @@ package memory import ( "bytes" "context" - "errors" + "sort" "sync" + "time" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/database/query" @@ -18,9 +19,9 @@ type store struct { memberRecords []*chat.MemberRecord messageRecords []*chat.MessageRecord - lastChatId uint64 - lastMemberId uint64 - lastMessageId uint64 + lastChatId int64 + lastMemberId int64 + lastMessageId int64 } // New returns a new in memory chat.Store @@ -75,31 +76,91 @@ func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, curs s.mu.Lock() defer s.mu.Unlock() - return nil, errors.New("not implemented") + items := s.findMessagesByChatId(chatId) + items, err := s.getMessageRecordPage(items, cursor, direction, limit) + if err != nil { + return nil, err + } + if len(items) == 0 { + return nil, chat.ErrMessageNotFound + } + + return cloneMessageRecords(items), nil } // PutChat creates a new chat func (s *store) PutChat(_ context.Context, record *chat.ChatRecord) error { + if err := record.Validate(); err != nil { + return err + } + s.mu.Lock() defer s.mu.Unlock() - return errors.New("not implemented") + s.lastChatId++ + + if item := s.findChat(record); item != nil { + return chat.ErrChatExists + } + + record.Id = s.lastChatId + if record.CreatedAt.IsZero() { + record.CreatedAt = time.Now() + } + + cloned := record.Clone() + s.chatRecords = append(s.chatRecords, &cloned) + + return nil } // PutMember creates a new chat member func (s *store) PutMember(_ context.Context, record *chat.MemberRecord) error { + if err := record.Validate(); err != nil { + return err + } + s.mu.Lock() defer s.mu.Unlock() - return errors.New("not implemented") + s.lastMemberId++ + + if item := s.findMember(record); item != nil { + return chat.ErrMemberExists + } + + record.Id = s.lastMemberId + if record.JoinedAt.IsZero() { + record.JoinedAt = time.Now() + } + + cloned := record.Clone() + s.memberRecords = append(s.memberRecords, &cloned) + + return nil } // PutMessage implements chat.Store.PutMessage func (s *store) PutMessage(_ context.Context, record *chat.MessageRecord) error { + if err := record.Validate(); err != nil { + return err + } + s.mu.Lock() defer s.mu.Unlock() - return errors.New("not implemented") + s.lastMessageId++ + + if item := s.findMessage(record); item != nil { + return chat.ErrMessageExsits + } + + record.Id = s.lastMessageId + + cloned := record.Clone() + s.messageRecords = append(s.messageRecords, &cloned) + + return nil } // AdvancePointer implements chat.Store.AdvancePointer @@ -170,6 +231,19 @@ func (s *store) SetSubscriptionState(_ context.Context, chatId chat.ChatId, memb return nil } +func (s *store) findChat(data *chat.ChatRecord) *chat.ChatRecord { + for _, item := range s.chatRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) { + return item + } + } + return nil +} + func (s *store) findChatById(chatId chat.ChatId) *chat.ChatRecord { for _, item := range s.chatRecords { if bytes.Equal(chatId[:], item.ChatId[:]) { @@ -179,6 +253,19 @@ func (s *store) findChatById(chatId chat.ChatId) *chat.ChatRecord { return nil } +func (s *store) findMember(data *chat.MemberRecord) *chat.MemberRecord { + for _, item := range s.memberRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MemberId[:], item.MemberId[:]) { + return item + } + } + return nil +} + func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat.MemberRecord { for _, item := range s.memberRecords { if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(memberId[:], item.MemberId[:]) { @@ -188,6 +275,19 @@ func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat return nil } +func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord { + for _, item := range s.messageRecords { + if data.Id == item.Id { + return item + } + + if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MessageId[:], item.MessageId[:]) { + return item + } + } + return nil +} + func (s *store) findMessageById(chatId chat.ChatId, messageId chat.MessageId) *chat.MessageRecord { for _, item := range s.messageRecords { if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(messageId[:], item.MessageId[:]) { @@ -197,6 +297,58 @@ func (s *store) findMessageById(chatId chat.ChatId, messageId chat.MessageId) *c return nil } +func (s *store) findMessagesByChatId(chatId chat.ChatId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range s.messageRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + res = append(res, item) + } + } + return res +} + +func (s *store) getMessageRecordPage(items []*chat.MessageRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { + if len(items) == 0 { + return nil, nil + } + + var messageIdCursor *chat.MessageId + if len(cursor) > 0 { + messageId, err := chat.GetMessageIdFromBytes(cursor) + if err != nil { + return nil, err + } + messageIdCursor = &messageId + } + + var res []*chat.MessageRecord + if messageIdCursor == nil { + res = items + } else { + for _, item := range items { + if item.MessageId.After(*messageIdCursor) && direction == query.Ascending { + res = append(res, item) + } + + if item.MessageId.Before(*messageIdCursor) && direction == query.Descending { + res = append(res, item) + } + } + } + + if direction == query.Ascending { + sort.Sort(chat.MessagesById(res)) + } else { + sort.Sort(sort.Reverse(chat.MessagesById(res))) + } + + if len(res) >= int(limit) { + return res[:limit], nil + } + + return res, nil +} + func (s *store) reset() { s.mu.Lock() defer s.mu.Unlock() @@ -209,3 +361,12 @@ func (s *store) reset() { s.lastMemberId = 0 s.lastMessageId = 0 } + +func cloneMessageRecords(items []*chat.MessageRecord) []*chat.MessageRecord { + res := make([]*chat.MessageRecord, len(items)) + for i, item := range items { + cloned := item.Clone() + res[i] = &cloned + } + return res +} diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index d68a76eb..47ce577c 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -97,6 +97,14 @@ type MessageRecord struct { // Note: No timestamp field, since it's encoded in MessageId } +type MessagesById []*MessageRecord + +func (a MessagesById) Len() int { return len(a) } +func (a MessagesById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MessagesById) Less(i, j int) bool { + return a[i].MessageId.Before(a[i].MessageId) +} + // GetChatIdFromProto gets a chat ID from the protobuf variant func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { switch proto { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index e7234b3a..2312e966 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -8,8 +8,11 @@ import ( ) var ( + ErrChatExists = errors.New("chat already exists") ErrChatNotFound = errors.New("chat not found") + ErrMemberExists = errors.New("chat member already exists") ErrMemberNotFound = errors.New("chat member not found") + ErrMessageExsits = errors.New("chat message already exists") ErrMessageNotFound = errors.New("chat message not found") ErrInvalidPointerType = errors.New("invalid pointer type") ) From f6edb857fe8a48b118081ed02ba4f931494c997c Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 11:14:13 -0400 Subject: [PATCH 28/71] Add missing result codes and update/comment on flush --- pkg/code/server/grpc/chat/v2/server.go | 56 ++++++++++++++++++++------ 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index b3f815fe..5796e312 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -126,7 +126,9 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest switch err { case nil: case chat.ErrChatNotFound: - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + }, nil default: log.WithError(err).Warn("failure getting chat record") return nil, status.Error(codes.Internal, "") @@ -137,7 +139,9 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") } else if !ownsChatMember { - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_DENIED, + }, nil } var limit uint64 @@ -175,6 +179,11 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, status.Error(codes.Internal, "") } + if len(protoChatMessages) == 0 { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + }, nil + } return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_OK, Messages: protoChatMessages, @@ -227,7 +236,11 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e switch err { case nil: case chat.ErrChatNotFound: - return status.Error(codes.Unimplemented, "todo: missing result code") + return streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Error{ + Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_CHAT_NOT_FOUND}, + }, + }) default: log.WithError(err).Warn("failure getting chat record") return status.Error(codes.Internal, "") @@ -238,7 +251,11 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e log.WithError(err).Warn("failure determing chat member ownership") return status.Error(codes.Internal, "") } else if !ownsChatMember { - return status.Error(codes.Unimplemented, "todo: missing result code") + return streamer.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Error{ + Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_DENIED}, + }, + }) } streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) @@ -282,7 +299,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) - go s.flush(ctx, chatId, owner, stream) + // todo: We should also "flush" pointers for each chat member + go s.flushMessages(ctx, chatId, owner, stream) for { select { @@ -330,9 +348,9 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } } -func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { +func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { log := s.log.WithFields(logrus.Fields{ - "method": "flush", + "method": "flushMessages", "chat_id": chatId.String(), "owner_account": owner.PublicKey().ToBase58(), }) @@ -347,7 +365,9 @@ func (s *server) flush(ctx context.Context, chatId chat.ChatId, owner *common.Ac query.WithDirection(query.Descending), query.WithLimit(flushMessageCount), ) - if err != nil { + if err == chat.ErrMessageNotFound { + return + } else if err != nil { log.WithError(err).Warn("failure getting chat messages") return } @@ -429,7 +449,9 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") } else if !ownsChatMember { - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.SendMessageResponse{ + Result: chatpb.SendMessageResponse_DENIED, + }, nil } chatLock := s.chatLocks.Get(chatId[:]) @@ -551,7 +573,9 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR switch pointerType { case chat.PointerTypeDelivered, chat.PointerTypeRead: default: - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_INVALID_POINTER_TYPE, + }, nil } pointerValue, err := chat.GetMessageIdFromProto(req.Pointer.Value) @@ -584,7 +608,9 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") } else if !ownsChatMember { - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.AdvancePointerResponse{ + Result: chatpb.AdvancePointerResponse_DENIED, + }, nil } _, err = s.data.GetChatMessageByIdV2(ctx, chatId, pointerValue) @@ -669,7 +695,9 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") } else if !isChatMember { - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_DENIED, + }, nil } err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) @@ -732,7 +760,9 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") } else if !ownsChatMember { - return nil, status.Error(codes.Unimplemented, "todo: missing result code") + return &chatpb.SetSubscriptionStateResponse{ + Result: chatpb.SetSubscriptionStateResponse_DENIED, + }, nil } err = s.data.SetChatSubscriptionStateV2(ctx, chatId, memberId, req.IsSubscribed) From ac1fe57a0442338c951fac2e6d92d9ae139a6d52 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 11:52:16 -0400 Subject: [PATCH 29/71] Flush pointers on chat event stream open --- pkg/code/data/chat/v2/memory/store.go | 34 +++++++++++++++++-- pkg/code/data/chat/v2/model.go | 14 ++++++++ pkg/code/data/chat/v2/store.go | 9 ++++-- pkg/code/data/internal.go | 6 +++- pkg/code/server/grpc/chat/v2/server.go | 45 +++++++++++++++++++++++++- 5 files changed, 101 insertions(+), 7 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index c5feb913..ce1184b5 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -71,8 +71,17 @@ func (s *store) GetMessageById(_ context.Context, chatId chat.ChatId, messageId return &cloned, nil } -// GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat -func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { +// GetAllMembersByChatId implements chat.Store.GetAllMembersByChatId +func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) { + items := s.findMembersByChatId(chatId) + if len(items) == 0 { + return nil, chat.ErrMemberNotFound + } + return cloneMemberRecords(items), nil +} + +// GetAllMessagesByChatId implements chat.Store.GetAllMessagesByChatId +func (s *store) GetAllMessagesByChatId(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { s.mu.Lock() defer s.mu.Unlock() @@ -81,10 +90,10 @@ func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, curs if err != nil { return nil, err } + if len(items) == 0 { return nil, chat.ErrMessageNotFound } - return cloneMessageRecords(items), nil } @@ -275,6 +284,16 @@ func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat return nil } +func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { + var res []*chat.MemberRecord + for _, item := range s.memberRecords { + if bytes.Equal(chatId[:], item.ChatId[:]) { + res = append(res, item) + } + } + return res +} + func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord { for _, item := range s.messageRecords { if data.Id == item.Id { @@ -362,6 +381,15 @@ func (s *store) reset() { s.lastMessageId = 0 } +func cloneMemberRecords(items []*chat.MemberRecord) []*chat.MemberRecord { + res := make([]*chat.MemberRecord, len(items)) + for i, item := range items { + cloned := item.Clone() + res[i] = &cloned + } + return res +} + func cloneMessageRecords(items []*chat.MessageRecord) []*chat.MessageRecord { res := make([]*chat.MessageRecord, len(items)) for i, item := range items { diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index 47ce577c..39df38a8 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -119,6 +119,20 @@ func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { } } +// ToProto returns the proto representation of the pointer type +func (p PointerType) ToProto() chatpb.Pointer_Kind { + switch p { + case PointerTypeSent: + return chatpb.Pointer_SENT + case PointerTypeDelivered: + return chatpb.Pointer_DELIVERED + case PointerTypeRead: + return chatpb.Pointer_READ + default: + return chatpb.Pointer_UNKNOWN + } +} + // String returns the string representation of the pointer type func (p PointerType) String() string { switch p { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 2312e966..7ecc10fc 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -28,10 +28,15 @@ type Store interface { // GetMessageById gets a chat message by the chat and message IDs GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) - // GetAllMessagesByChat gets all messages for a given chat + // GetAllMembersByChatId gets all members for a given chat + // + // todo: Add paging when we introduce group chats + GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) + + // GetAllMessagesByChatId gets all messages for a given chat // // Note: Cursor is a message ID - GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) // PutChat creates a new chat PutChat(ctx context.Context, record *ChatRecord) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 68cce410..ff0220f5 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -399,6 +399,7 @@ type DatabaseData interface { GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error @@ -1472,12 +1473,15 @@ func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { return dp.chatv2.GetMessageById(ctx, chatId, messageId) } +func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { + return dp.chatv2.GetAllMembersByChatId(ctx, chatId) +} func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chatv2.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { return dp.chatv2.PutChat(ctx, record) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 5796e312..7e33bd96 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -299,8 +299,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) - // todo: We should also "flush" pointers for each chat member go s.flushMessages(ctx, chatId, owner, stream) + go s.flushPointers(ctx, chatId, stream) for { select { @@ -385,6 +385,49 @@ func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *c } } +func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushPointers", + "chat_id": chatId.String(), + }) + + memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatId) + if err == chat.ErrMemberNotFound { + return + } else if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + for _, memberRecord := range memberRecords { + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: &chatpb.Pointer{ + Kind: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } + } +} + func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { log := s.log.WithField("method", "SendMessage") log = client.InjectLoggingMetadata(ctx, log) From 35597d3fb2879474bbedd72a834fd345e1b6b91d Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 12:22:56 -0400 Subject: [PATCH 30/71] Address todos relating to copying pointer values in chat store --- pkg/code/data/chat/v2/id.go | 21 ++++++++ pkg/code/data/chat/v2/memory/store.go | 6 ++- pkg/code/data/chat/v2/model.go | 77 ++++++++++++++++++++------- pkg/pointer/pointer.go | 30 +++++++++++ 4 files changed, 112 insertions(+), 22 deletions(-) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index 4f7b7b0a..26341e6e 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -48,6 +48,13 @@ func (c ChatId) Validate() error { return nil } +// Clone clones a chat ID +func (c ChatId) Clone() ChatId { + var cloned ChatId + copy(cloned[:], c[:]) + return cloned +} + // String returns the string representation of a ChatId func (c ChatId) String() string { return hex.EncodeToString(c[:]) @@ -102,6 +109,13 @@ func (m MemberId) Validate() error { return nil } +// Clone clones a chat member ID +func (m MemberId) Clone() MemberId { + var cloned MemberId + copy(cloned[:], m[:]) + return cloned +} + // String returns the string representation of a MemberId func (m MemberId) String() string { return uuid.UUID(m).String() @@ -218,6 +232,13 @@ func (m MessageId) Validate() error { return nil } +// Clone clones a chat message ID +func (m MessageId) Clone() MessageId { + var cloned MessageId + copy(cloned[:], m[:]) + return cloned +} + // String returns the string representation of a MessageId func (m MessageId) String() string { return uuid.UUID(m).String() diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index ce1184b5..2c051fad 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -202,9 +202,11 @@ func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId c switch pointerType { case chat.PointerTypeDelivered: - item.DeliveryPointer = &pointer // todo: pointer copy safety + cloned := pointer.Clone() + item.DeliveryPointer = &cloned case chat.PointerTypeRead: - item.ReadPointer = &pointer // todo: pointer copy safety + cloned := pointer.Clone() + item.ReadPointer = &cloned } return nil diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index 39df38a8..f30b78c1 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -11,7 +11,7 @@ import ( chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" ) -type ChatType int +type ChatType uint8 const ( ChatTypeUnknown ChatType = iota @@ -20,7 +20,7 @@ const ( // ChatTypeGroup ) -type ReferenceType int +type ReferenceType uint8 const ( ReferenceTypeUnknown ReferenceType = iota @@ -28,7 +28,7 @@ const ( ReferenceTypeSignature ) -type PointerType int +type PointerType uint8 const ( PointerTypeUnknown PointerType = iota @@ -37,7 +37,7 @@ const ( PointerTypeRead ) -type Platform int +type Platform uint8 const ( PlatformUnknown Platform = iota @@ -256,6 +256,18 @@ func (r *MemberRecord) Validate() error { // Clone clones a member record func (r *MemberRecord) Clone() MemberRecord { + var deliveryPointerCopy *MessageId + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + deliveryPointerCopy = &cloned + } + + var readPointerCopy *MessageId + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + readPointerCopy = &cloned + } + return MemberRecord{ Id: r.Id, ChatId: r.ChatId, @@ -264,8 +276,8 @@ func (r *MemberRecord) Clone() MemberRecord { Platform: r.Platform, PlatformId: r.PlatformId, - DeliveryPointer: r.DeliveryPointer, // todo: pointer copy safety - ReadPointer: r.ReadPointer, // todo: pointer copy safety + DeliveryPointer: deliveryPointerCopy, + ReadPointer: readPointerCopy, IsMuted: r.IsMuted, IsUnsubscribed: r.IsUnsubscribed, @@ -283,8 +295,14 @@ func (r *MemberRecord) CopyTo(dst *MemberRecord) { dst.Platform = r.Platform dst.PlatformId = r.PlatformId - dst.DeliveryPointer = r.DeliveryPointer // todo: pointer copy safety - dst.ReadPointer = r.ReadPointer // todo: pointer copy safety + if r.DeliveryPointer != nil { + cloned := r.DeliveryPointer.Clone() + dst.DeliveryPointer = &cloned + } + if r.ReadPointer != nil { + cloned := r.ReadPointer.Clone() + dst.ReadPointer = &cloned + } dst.IsMuted = r.IsMuted dst.IsUnsubscribed = r.IsUnsubscribed @@ -350,21 +368,34 @@ func (r *MessageRecord) Validate() error { // Clone clones a message record func (r *MessageRecord) Clone() MessageRecord { + var senderCopy *MemberId + if r.Sender != nil { + cloned := r.Sender.Clone() + senderCopy = &cloned + } + + dataCopy := make([]byte, len(r.Data)) + copy(dataCopy, r.Data) + + var referenceTypeCopy *ReferenceType + if r.ReferenceType != nil { + cloned := *r.ReferenceType + referenceTypeCopy = &cloned + } + return MessageRecord{ Id: r.Id, ChatId: r.ChatId, MessageId: r.MessageId, - Sender: r.Sender, // todo: pointer copy safety + Sender: senderCopy, - Data: r.Data, // todo: pointer copy safety + Data: dataCopy, - ReferenceType: r.ReferenceType, // todo: pointer copy safety - Reference: r.Reference, // todo: pointer copy safety + ReferenceType: referenceTypeCopy, + Reference: pointer.StringCopy(r.Reference), IsSilent: r.IsSilent, - - // todo: finish implementing me } } @@ -374,16 +405,22 @@ func (r *MessageRecord) CopyTo(dst *MessageRecord) { dst.ChatId = r.ChatId dst.MessageId = r.MessageId - dst.Sender = r.Sender // todo: pointer copy safety + if r.Sender != nil { + cloned := r.Sender.Clone() + dst.Sender = &cloned + } - dst.Data = r.Data // todo: pointer copy safety + dataCopy := make([]byte, len(r.Data)) + copy(dataCopy, r.Data) + dst.Data = dataCopy - dst.ReferenceType = r.ReferenceType // todo: pointer copy safety - dst.Reference = r.Reference // todo: pointer copy safety + if r.ReferenceType != nil { + cloned := *r.ReferenceType + dst.ReferenceType = &cloned + } + dst.Reference = pointer.StringCopy(r.Reference) dst.IsSilent = r.IsSilent - - // todo: finish implementing me } // GetTimestamp gets the timestamp for a message record diff --git a/pkg/pointer/pointer.go b/pkg/pointer/pointer.go index a3f8da02..a353d347 100644 --- a/pkg/pointer/pointer.go +++ b/pkg/pointer/pointer.go @@ -32,6 +32,36 @@ func StringCopy(value *string) *string { return String(*value) } +// Uint8 returns a pointer to the provided uint8 value +func Uint8(value uint8) *uint8 { + return &value +} + +// Uint8OrDefault returns the pointer if not nil, otherwise the default value +func Uint8OrDefault(value *uint8, defaultValue uint8) *uint8 { + if value != nil { + return value + } + return &defaultValue +} + +// Uint8IfValid returns a pointer to the value if it's valid, otherwise nil +func Uint8IfValid(valid bool, value uint8) *uint8 { + if valid { + return &value + } + return nil +} + +// Uint8Copy returns a pointer that's a copy of the provided value +func Uint8Copy(value *uint8) *uint8 { + if value == nil { + return nil + } + + return Uint8(*value) +} + // Uint64 returns a pointer to the provided uint64 value func Uint64(value uint64) *uint64 { return &value From da259356ec4fa4dd44a2881270039894fe1049a7 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 14:34:59 -0400 Subject: [PATCH 31/71] Fix result codes in GetMessages --- pkg/code/server/grpc/chat/v2/server.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 7e33bd96..775d848f 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -127,7 +127,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest case nil: case chat.ErrChatNotFound: return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + Result: chatpb.GetMessagesResponse_CHAT_NOT_FOUND, }, nil default: log.WithError(err).Warn("failure getting chat record") @@ -174,7 +174,11 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest query.WithDirection(direction), query.WithLimit(limit), ) - if err != nil { + if err == chat.ErrMessageNotFound { + return &chatpb.GetMessagesResponse{ + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, + }, nil + } else if err != nil { log.WithError(err).Warn("failure getting chat messages") return nil, status.Error(codes.Internal, "") } From 24268d6c6294f581a37428b688bdb77436cd047c Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 15:24:28 -0400 Subject: [PATCH 32/71] Setup a temporary mock chat for testing --- pkg/code/data/chat/v2/id.go | 30 +++++++++++++++++++++ pkg/code/server/grpc/chat/v2/server.go | 36 ++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index 26341e6e..de912e48 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -29,6 +29,16 @@ func GetChatIdFromBytes(buffer []byte) (ChatId, error) { return typed, nil } +// GetChatIdFromBytes gets a chat ID from the string representation +func GetChatIdFromString(value string) (ChatId, error) { + decoded, err := hex.DecodeString(value) + if err != nil { + return ChatId{}, errors.Wrap(err, "value is not a hexadecimal string") + } + + return GetChatIdFromBytes(decoded) +} + // GetChatIdFromProto gets a chat ID from the protobuf variant func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { if err := proto.Validate(); err != nil { @@ -84,6 +94,16 @@ func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { return typed, nil } +// GetMemberIdFromString gets a chat member ID from the string representation +func GetMemberIdFromString(value string) (MemberId, error) { + decoded, err := uuid.Parse(value) + if err != nil { + return MemberId{}, errors.Wrap(err, "value is not a uuid string") + } + + return GetMemberIdFromBytes(decoded[:]) +} + // GetMemberIdFromProto gets a member ID from the protobuf variant func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { if err := proto.Validate(); err != nil { @@ -171,6 +191,16 @@ func GetMessageIdFromBytes(buffer []byte) (MessageId, error) { return typed, nil } +// GetMessageIdFromString gets a chat message ID from the string representation +func GetMessageIdFromString(value string) (MessageId, error) { + decoded, err := uuid.Parse(value) + if err != nil { + return MessageId{}, errors.Wrap(err, "value is not a uuid string") + } + + return GetMessageIdFromBytes(decoded[:]) +} + // GetMessageIdFromProto gets a message ID from the protobuf variant func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { if err := proto.Validate(); err != nil { diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 775d848f..a10f2ef9 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -68,9 +68,45 @@ func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier go s.asyncChatEventStreamNotifier(i, channel) } + // todo: Remove when testing is complete + s.setupMockChat() + return s } +func (s *server) setupMockChat() { + ctx := context.Background() + + chatId, _ := chat.GetChatIdFromString("c355fcec8c521e7937d45283d83bbfc63a0c688004f2386a535fc817218f917b") + chatRecord := &chat.ChatRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + IsVerified: true, + CreatedAt: time.Now(), + } + s.data.PutChatV2(ctx, chatRecord) + + memberId1, _ := chat.GetMemberIdFromString("034dda45-b4c2-45db-b1da-181298898a16") + memberRecord1 := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId1, + Platform: chat.PlatformCode, + PlatformId: "8bw4gaRQk91w7vtgTN4E12GnKecY2y6CjPai7WUvWBQ8", + JoinedAt: time.Now(), + } + s.data.PutChatMemberV2(ctx, memberRecord1) + + memberId2, _ := chat.GetMemberIdFromString("a9d27058-f2d8-4034-bf52-b20c09a670de") + memberRecord2 := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId2, + Platform: chat.PlatformCode, + PlatformId: "EDknQfoUnj73L56vKtEc6Qqw5VoHaF32eHYdz3V4y27M", + JoinedAt: time.Now(), + } + s.data.PutChatMemberV2(ctx, memberRecord2) +} + func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") log = client.InjectLoggingMetadata(ctx, log) From f82841161fbeaf67fac7a4e3ac115000cb03b0bc Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 12 Jun 2024 09:47:28 -0400 Subject: [PATCH 33/71] Fix message ID sorting --- pkg/code/data/chat/v2/model.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index f30b78c1..6d5124b0 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -102,7 +102,7 @@ type MessagesById []*MessageRecord func (a MessagesById) Len() int { return len(a) } func (a MessagesById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a MessagesById) Less(i, j int) bool { - return a[i].MessageId.Before(a[i].MessageId) + return a[i].MessageId.Before(a[j].MessageId) } // GetChatIdFromProto gets a chat ID from the protobuf variant From e8e873f3d4d74ba7eb618a6ce494fbac29ac18ce Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 12 Jun 2024 10:22:22 -0400 Subject: [PATCH 34/71] flushMessages doesn't need a cursor value for the DB query --- pkg/code/server/grpc/chat/v2/server.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index a10f2ef9..00f7749d 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -395,13 +395,11 @@ func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *c "owner_account": owner.PublicKey().ToBase58(), }) - cursorValue := chat.GenerateMessageIdAtTime(time.Now().Add(2 * time.Second)) - protoChatMessages, err := s.getProtoChatMessages( ctx, chatId, owner, - query.WithCursor(cursorValue[:]), + query.WithCursor(query.EmptyCursor), query.WithDirection(query.Descending), query.WithLimit(flushMessageCount), ) From 932a60eda6474cbb5988d6b976e0b5508db7c166 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 13 Jun 2024 10:40:30 -0400 Subject: [PATCH 35/71] Implement GetChats RPC limited to anonymous chat membership --- pkg/code/data/chat/v2/memory/store.go | 101 ++++++++++++++++- pkg/code/data/chat/v2/model.go | 76 ++++++++++++- pkg/code/data/chat/v2/store.go | 7 ++ pkg/code/data/internal.go | 12 ++ pkg/code/server/grpc/chat/v2/server.go | 148 ++++++++++++++++++++++++- 5 files changed, 336 insertions(+), 8 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index 2c051fad..97222fab 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -80,6 +80,34 @@ func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([] return cloneMemberRecords(items), nil } +// GetAllMembersByPlatformId implements chat.store.GetAllMembersByPlatformId +func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findMembersByPlatformId(platform, platformId) + items, err := s.getMemberRecordPage(items, cursor, direction, limit) + if err != nil { + return nil, err + } + + if len(items) == 0 { + return nil, chat.ErrMemberNotFound + } + return cloneMemberRecords(items), nil +} + +// GetUnreadCount implements chat.store.GetUnreadCount +func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, readPointer chat.MessageId) (uint32, error) { + s.mu.Lock() + defer s.mu.Unlock() + + items := s.findMessagesByChatId(chatId) + items = s.filterMessagesAfter(items, readPointer) + items = s.filterNotifiedMessages(items) + return uint32(len(items)), nil +} + // GetAllMessagesByChatId implements chat.Store.GetAllMessagesByChatId func (s *store) GetAllMessagesByChatId(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { s.mu.Lock() @@ -296,6 +324,55 @@ func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { return res } +func (s *store) findMembersByPlatformId(platform chat.Platform, platformId string) []*chat.MemberRecord { + var res []*chat.MemberRecord + for _, item := range s.memberRecords { + if platform == item.Platform && platformId == item.PlatformId { + res = append(res, item) + } + } + return res +} + +func (s *store) getMemberRecordPage(items []*chat.MemberRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { + if len(items) == 0 { + return nil, nil + } + + var memberIdCursor *uint64 + if len(cursor) > 0 { + cursorValue := query.FromCursor(cursor) + memberIdCursor = &cursorValue + } + + var res []*chat.MemberRecord + if memberIdCursor == nil { + res = items + } else { + for _, item := range items { + if item.Id > int64(*memberIdCursor) && direction == query.Ascending { + res = append(res, item) + } + + if item.Id < int64(*memberIdCursor) && direction == query.Descending { + res = append(res, item) + } + } + } + + if direction == query.Ascending { + sort.Sort(chat.MembersById(res)) + } else { + sort.Sort(sort.Reverse(chat.MembersById(res))) + } + + if len(res) >= int(limit) { + return res[:limit], nil + } + + return res, nil +} + func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord { for _, item := range s.messageRecords { if data.Id == item.Id { @@ -328,6 +405,26 @@ func (s *store) findMessagesByChatId(chatId chat.ChatId) []*chat.MessageRecord { return res } +func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.MessageId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if item.MessageId.After(pointer) { + res = append(res, item) + } + } + return res +} + +func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if !item.IsSilent { + res = append(res, item) + } + } + return res +} + func (s *store) getMessageRecordPage(items []*chat.MessageRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { if len(items) == 0 { return nil, nil @@ -358,9 +455,9 @@ func (s *store) getMessageRecordPage(items []*chat.MessageRecord, cursor query.C } if direction == query.Ascending { - sort.Sort(chat.MessagesById(res)) + sort.Sort(chat.MessagesByMessageId(res)) } else { - sort.Sort(sort.Reverse(chat.MessagesById(res))) + sort.Sort(sort.Reverse(chat.MessagesByMessageId(res))) } if len(res) >= int(limit) { diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index 6d5124b0..9a7108f3 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -97,15 +97,59 @@ type MessageRecord struct { // Note: No timestamp field, since it's encoded in MessageId } -type MessagesById []*MessageRecord +type MembersById []*MemberRecord -func (a MessagesById) Len() int { return len(a) } -func (a MessagesById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a MessagesById) Less(i, j int) bool { +func (a MembersById) Len() int { return len(a) } +func (a MembersById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MembersById) Less(i, j int) bool { + return a[i].Id < a[j].Id +} + +type MessagesByMessageId []*MessageRecord + +func (a MessagesByMessageId) Len() int { return len(a) } +func (a MessagesByMessageId) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MessagesByMessageId) Less(i, j int) bool { return a[i].MessageId.Before(a[j].MessageId) } -// GetChatIdFromProto gets a chat ID from the protobuf variant +// GetChatTypeFromProto gets a chat type from the protobuf variant +func GetChatTypeFromProto(proto chatpb.ChatMetadata_Kind) ChatType { + switch proto { + case chatpb.ChatMetadata_NOTIFICATION: + return ChatTypeNotification + case chatpb.ChatMetadata_TWO_WAY: + return ChatTypeTwoWay + default: + return ChatTypeUnknown + } +} + +// ToProto returns the proto representation of the chat type +func (c ChatType) ToProto() chatpb.ChatMetadata_Kind { + switch c { + case ChatTypeNotification: + return chatpb.ChatMetadata_NOTIFICATION + case ChatTypeTwoWay: + return chatpb.ChatMetadata_TWO_WAY + default: + return chatpb.ChatMetadata_UNKNOWN + } +} + +// String returns the string representation of the chat type +func (c ChatType) String() string { + switch c { + case ChatTypeNotification: + return "notification" + case ChatTypeTwoWay: + return "two-way" + default: + return "unknown" + } +} + +// GetPointerTypeFromProto gets a chat ID from the protobuf variant func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { switch proto { case chatpb.Pointer_SENT: @@ -147,6 +191,28 @@ func (p PointerType) String() string { } } +// ToProto returns the proto representation of the platform +func (p Platform) ToProto() chatpb.ChatMemberIdentity_Platform { + switch p { + case PlatformTwitter: + return chatpb.ChatMemberIdentity_TWITTER + default: + return chatpb.ChatMemberIdentity_UNKNOWN + } +} + +// String returns the string representation of the platform +func (p Platform) String() string { + switch p { + case PlatformCode: + return "code" + case PlatformTwitter: + return "twitter" + default: + return "unknown" + } +} + // Validate validates a chat Record func (r *ChatRecord) Validate() error { if err := r.ChatId.Validate(); err != nil { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 7ecc10fc..5a0e9f20 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -33,11 +33,18 @@ type Store interface { // todo: Add paging when we introduce group chats GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) + // GetAllMembersByPlatformId gets all members for a given platform user across + // all chats + GetAllMembersByPlatformId(ctx context.Context, platform Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + // GetAllMessagesByChatId gets all messages for a given chat // // Note: Cursor is a message ID GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + // GetUnreadCount gets the unread message count for a chat ID at a read pointer + GetUnreadCount(ctx context.Context, chatId ChatId, readPointer MessageId) (uint32, error) + // PutChat creates a new chat PutChat(ctx context.Context, record *ChatRecord) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index ff0220f5..68ab5816 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -400,7 +400,9 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) + GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error @@ -1476,6 +1478,13 @@ func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId cha func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { return dp.chatv2.GetAllMembersByChatId(ctx, chatId) } +func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } + return dp.chatv2.GetAllMembersByPlatformId(ctx, platform, platformId, req.Cursor, req.SortBy, req.Limit) +} func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { @@ -1483,6 +1492,9 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetUnreadCount(ctx, chatId, readPointer) +} func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { return dp.chatv2.PutChat(ctx, record) } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 00f7749d..d0448565 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1,8 +1,10 @@ package chat_v2 import ( + "bytes" "context" "fmt" + "math" "sync" "time" @@ -31,6 +33,7 @@ import ( ) const ( + maxGetChatsPageSize = 100 maxGetMessagesPageSize = 100 flushMessageCount = 100 ) @@ -107,6 +110,7 @@ func (s *server) setupMockChat() { s.data.PutChatMemberV2(ctx, memberRecord2) } +// todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") log = client.InjectLoggingMetadata(ctx, log) @@ -124,7 +128,149 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch return nil, err } - return nil, status.Error(codes.Unimplemented, "") + var limit uint64 + if req.PageSize > 0 { + limit = uint64(req.PageSize) + } else { + limit = maxGetChatsPageSize + } + if limit > maxGetChatsPageSize { + limit = maxGetChatsPageSize + } + + var direction query.Ordering + if req.Direction == chatpb.GetChatsRequest_ASC { + direction = query.Ascending + } else { + direction = query.Descending + } + + var cursor query.Cursor + if req.Cursor != nil { + cursor = req.Cursor.Value + } else { + cursor = query.ToCursor(0) + if direction == query.Descending { + cursor = query.ToCursor(math.MaxInt64 - 1) + } + } + + patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( + ctx, + chat.PlatformCode, // todo: support other platforms once we support revealing identity + owner.PublicKey().ToBase58(), + query.WithCursor(cursor), + query.WithDirection(direction), + query.WithLimit(limit), + ) + if err == chat.ErrMemberNotFound { + return &chatpb.GetChatsResponse{ + Result: chatpb.GetChatsResponse_NOT_FOUND, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting chat members for platform user") + return nil, status.Error(codes.Internal, "") + } + + var protoChats []*chatpb.ChatMetadata + for _, platformUserMemberRecord := range patformUserMemberRecords { + log := log.WithField("chat_id", platformUserMemberRecord.ChatId.String()) + + chatRecord, err := s.data.GetChatByIdV2(ctx, platformUserMemberRecord.ChatId) + if err != nil { + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + protoChat := &chatpb.ChatMetadata{ + ChatId: chatRecord.ChatId.ToProto(), + Kind: chatRecord.ChatType.ToProto(), + + IsMuted: platformUserMemberRecord.IsMuted, + IsSubscribed: !platformUserMemberRecord.IsUnsubscribed, + + Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))}, + } + + // Unread count calculations can be skipped for unsubscribed chats. They + // don't appear in chat history. + skipUnreadCountQuery := platformUserMemberRecord.IsUnsubscribed + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + protoChat.Title = "Mock Chat" // todo: proper title with localization + + protoChat.CanMute = true + protoChat.CanUnsubscribe = true + default: + return nil, status.Errorf(codes.Unimplemented, "unsupported chat type: %s", chatRecord.ChatType.String()) + } + + chatMemberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return nil, status.Error(codes.Internal, "") + } + for _, memberRecord := range chatMemberRecords { + var identity *chatpb.ChatMemberIdentity + switch memberRecord.Platform { + case chat.PlatformCode: + case chat.PlatformTwitter: + identity = &chatpb.ChatMemberIdentity{ + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + } + default: + return nil, status.Errorf(codes.Unimplemented, "unsupported platform type: %s", memberRecord.Platform.String()) + } + + var pointers []*chatpb.Pointer + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + pointers = append(pointers, &chatpb.Pointer{ + Kind: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }) + } + + protoChat.Members = append(protoChat.Members, &chatpb.ChatMember{ + MemberId: memberRecord.MemberId.ToProto(), + IsSelf: bytes.Equal(memberRecord.MemberId[:], platformUserMemberRecord.MemberId[:]), + Identity: identity, + Pointers: pointers, + }) + } + + if !skipUnreadCountQuery { + readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) + if platformUserMemberRecord.ReadPointer != nil { + readPointer = *platformUserMemberRecord.ReadPointer + } + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, readPointer) + if err != nil { + log.WithError(err).Warn("failure getting unread count") + return nil, status.Error(codes.Internal, "") + } + protoChat.NumUnread = unreadCount + } + + protoChats = append(protoChats, protoChat) + } + + return &chatpb.GetChatsResponse{ + Result: chatpb.GetChatsResponse_OK, + Chats: protoChats, + }, nil } func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { From 712a1f260a72f4d3fb0a66c4642c393b0705eec4 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 13 Jun 2024 12:20:40 -0400 Subject: [PATCH 36/71] Remove addressed todo --- pkg/code/server/grpc/chat/v2/server.go | 1 - 1 file changed, 1 deletion(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index d0448565..f1d24916 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -38,7 +38,6 @@ const ( flushMessageCount = 100 ) -// todo: Ensure all relevant logging fields are set type server struct { log *logrus.Entry From 47df62819ec73845a212f1401b65fb46176393a6 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 13 Jun 2024 15:39:14 -0400 Subject: [PATCH 37/71] Have chat v2 store indicate if pointer was advanced --- pkg/code/data/chat/v2/memory/store.go | 29 +++++++++++++------------- pkg/code/data/chat/v2/store.go | 2 +- pkg/code/data/internal.go | 4 ++-- pkg/code/server/grpc/chat/v2/server.go | 19 +++++++++-------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index 97222fab..fc17e977 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -201,11 +201,11 @@ func (s *store) PutMessage(_ context.Context, record *chat.MessageRecord) error } // AdvancePointer implements chat.Store.AdvancePointer -func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) error { +func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) (bool, error) { switch pointerType { case chat.PointerTypeDelivered, chat.PointerTypeRead: default: - return chat.ErrInvalidPointerType + return false, chat.ErrInvalidPointerType } s.mu.Lock() @@ -213,7 +213,7 @@ func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId c item := s.findMemberById(chatId, memberId) if item == nil { - return chat.ErrMemberNotFound + return false, chat.ErrMemberNotFound } var currentPointer *chat.MessageId @@ -224,20 +224,19 @@ func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId c currentPointer = item.ReadPointer } - if currentPointer != nil && currentPointer.After(pointer) { - return nil - } + if currentPointer == nil || currentPointer.Before(pointer) { + switch pointerType { + case chat.PointerTypeDelivered: + cloned := pointer.Clone() + item.DeliveryPointer = &cloned + case chat.PointerTypeRead: + cloned := pointer.Clone() + item.ReadPointer = &cloned + } - switch pointerType { - case chat.PointerTypeDelivered: - cloned := pointer.Clone() - item.DeliveryPointer = &cloned - case chat.PointerTypeRead: - cloned := pointer.Clone() - item.ReadPointer = &cloned + return true, nil } - - return nil + return false, nil } // SetMuteState implements chat.Store.SetMuteState diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 5a0e9f20..82e8150f 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -55,7 +55,7 @@ type Store interface { PutMessage(ctx context.Context, record *MessageRecord) error // AdvancePointer advances a chat pointer for a chat member - AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) error + AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) // SetMuteState updates the mute state for a chat member SetMuteState(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 68ab5816..c4a06397 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -406,7 +406,7 @@ type DatabaseData interface { PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error - AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error + AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error @@ -1504,7 +1504,7 @@ func (dp *DatabaseProvider) PutChatMemberV2(ctx context.Context, record *chat_v2 func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { return dp.chatv2.PutMessage(ctx, record) } -func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) error { +func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) } func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index f1d24916..47287f23 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -851,20 +851,21 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.Internal, "") } - // Note: Guarantees that pointer will never be advanced to some point in the past - err = s.data.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointerValue) + isAdvanced, err := s.data.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointerValue) if err != nil { log.WithError(err).Warn("failure advancing chat pointer") return nil, status.Error(codes.Internal, "") } - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Pointer{ - Pointer: req.Pointer, - }, - } - if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") + if isAdvanced { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: req.Pointer, + }, + } + if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } } return &chatpb.AdvancePointerResponse{ From 3e22628cc05cecebcd67c3562b226880ccf441a6 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 13 Jun 2024 16:54:07 -0400 Subject: [PATCH 38/71] Fix unread count to not count messages sent by the reader --- pkg/code/data/chat/v2/memory/store.go | 13 ++++++++++++- pkg/code/data/chat/v2/store.go | 4 ++-- pkg/code/data/internal.go | 6 +++--- pkg/code/server/grpc/chat/v2/server.go | 2 +- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index fc17e977..f11fc44d 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -98,12 +98,13 @@ func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platf } // GetUnreadCount implements chat.store.GetUnreadCount -func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, readPointer chat.MessageId) (uint32, error) { +func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) { s.mu.Lock() defer s.mu.Unlock() items := s.findMessagesByChatId(chatId) items = s.filterMessagesAfter(items, readPointer) + items = s.filterMessagesNotSentBy(items, memberId) items = s.filterNotifiedMessages(items) return uint32(len(items)), nil } @@ -414,6 +415,16 @@ func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.Me return res } +func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) { + res = append(res, item) + } + } + return res +} + func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { var res []*chat.MessageRecord for _, item := range items { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 82e8150f..f58d34b4 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -42,8 +42,8 @@ type Store interface { // Note: Cursor is a message ID GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) - // GetUnreadCount gets the unread message count for a chat ID at a read pointer - GetUnreadCount(ctx context.Context, chatId ChatId, readPointer MessageId) (uint32, error) + // GetUnreadCount gets the unread message count for a chat ID at a read pointer for a given chat member + GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error) // PutChat creates a new chat PutChat(ctx context.Context, record *ChatRecord) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index c4a06397..93248a4c 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -402,7 +402,7 @@ type DatabaseData interface { GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) - GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error @@ -1492,8 +1492,8 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) { - return dp.chatv2.GetUnreadCount(ctx, chatId, readPointer) +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer) } func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { return dp.chatv2.PutChat(ctx, record) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 47287f23..f71b7f8e 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -255,7 +255,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch if platformUserMemberRecord.ReadPointer != nil { readPointer = *platformUserMemberRecord.ReadPointer } - unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, readPointer) + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, platformUserMemberRecord.MemberId, readPointer) if err != nil { log.WithError(err).Warn("failure getting unread count") return nil, status.Error(codes.Internal, "") From 8a4c36e42642773866c9a4979bc3c9bdcb2462f3 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 17 Jun 2024 13:23:58 -0400 Subject: [PATCH 39/71] Fix build with refactor changes to chat protos --- pkg/code/data/chat/v2/model.go | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index 9a7108f3..e0cb5ed9 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -114,11 +114,11 @@ func (a MessagesByMessageId) Less(i, j int) bool { } // GetChatTypeFromProto gets a chat type from the protobuf variant -func GetChatTypeFromProto(proto chatpb.ChatMetadata_Kind) ChatType { +func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { switch proto { - case chatpb.ChatMetadata_NOTIFICATION: + case chatpb.ChatType_NOTIFICATION: return ChatTypeNotification - case chatpb.ChatMetadata_TWO_WAY: + case chatpb.ChatType_TWO_WAY: return ChatTypeTwoWay default: return ChatTypeUnknown @@ -126,14 +126,14 @@ func GetChatTypeFromProto(proto chatpb.ChatMetadata_Kind) ChatType { } // ToProto returns the proto representation of the chat type -func (c ChatType) ToProto() chatpb.ChatMetadata_Kind { +func (c ChatType) ToProto() chatpb.ChatType { switch c { case ChatTypeNotification: - return chatpb.ChatMetadata_NOTIFICATION + return chatpb.ChatType_NOTIFICATION case ChatTypeTwoWay: - return chatpb.ChatMetadata_TWO_WAY + return chatpb.ChatType_TWO_WAY default: - return chatpb.ChatMetadata_UNKNOWN + return chatpb.ChatType_UNKNOWN_CHAT_TYPE } } @@ -150,13 +150,13 @@ func (c ChatType) String() string { } // GetPointerTypeFromProto gets a chat ID from the protobuf variant -func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { +func GetPointerTypeFromProto(proto chatpb.PointerType) PointerType { switch proto { - case chatpb.Pointer_SENT: + case chatpb.PointerType_SENT: return PointerTypeSent - case chatpb.Pointer_DELIVERED: + case chatpb.PointerType_DELIVERED: return PointerTypeDelivered - case chatpb.Pointer_READ: + case chatpb.PointerType_READ: return PointerTypeRead default: return PointerTypeUnknown @@ -164,16 +164,16 @@ func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { } // ToProto returns the proto representation of the pointer type -func (p PointerType) ToProto() chatpb.Pointer_Kind { +func (p PointerType) ToProto() chatpb.PointerType { switch p { case PointerTypeSent: - return chatpb.Pointer_SENT + return chatpb.PointerType_SENT case PointerTypeDelivered: - return chatpb.Pointer_DELIVERED + return chatpb.PointerType_DELIVERED case PointerTypeRead: - return chatpb.Pointer_READ + return chatpb.PointerType_READ default: - return chatpb.Pointer_UNKNOWN + return chatpb.PointerType_UNKNOWN_POINTER_TYPE } } From a3c5bdc3245f2df0bd2f22d15c22acce849843ae Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 18 Jun 2024 11:55:46 -0400 Subject: [PATCH 40/71] Initial implementation of StartChat that always starts a new chat --- pkg/code/data/chat/v2/memory/store.go | 15 +- pkg/code/data/chat/v2/store.go | 5 +- pkg/code/data/internal.go | 6 +- pkg/code/server/grpc/chat/v2/server.go | 463 ++++++++++++++++++------- 4 files changed, 347 insertions(+), 142 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index f11fc44d..905872f3 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -80,12 +80,12 @@ func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([] return cloneMemberRecords(items), nil } -// GetAllMembersByPlatformId implements chat.store.GetAllMembersByPlatformId -func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { +// GetAllMembersByPlatformIds implements chat.store.GetAllMembersByPlatformIds +func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[chat.Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { s.mu.Lock() defer s.mu.Unlock() - items := s.findMembersByPlatformId(platform, platformId) + items := s.findMembersByPlatformIds(idByPlatform) items, err := s.getMemberRecordPage(items, cursor, direction, limit) if err != nil { return nil, err @@ -324,10 +324,15 @@ func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { return res } -func (s *store) findMembersByPlatformId(platform chat.Platform, platformId string) []*chat.MemberRecord { +func (s *store) findMembersByPlatformIds(idByPlatform map[chat.Platform]string) []*chat.MemberRecord { var res []*chat.MemberRecord for _, item := range s.memberRecords { - if platform == item.Platform && platformId == item.PlatformId { + platformId, ok := idByPlatform[item.Platform] + if !ok { + continue + } + + if platformId == item.PlatformId { res = append(res, item) } } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index f58d34b4..957260e8 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -33,9 +33,8 @@ type Store interface { // todo: Add paging when we introduce group chats GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) - // GetAllMembersByPlatformId gets all members for a given platform user across - // all chats - GetAllMembersByPlatformId(ctx context.Context, platform Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + // GetAllMembersByPlatformIds gets all members for platform users across all chats + GetAllMembersByPlatformIds(ctx context.Context, idByPlatform map[Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) // GetAllMessagesByChatId gets all messages for a given chat // diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 93248a4c..11808d3e 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -400,7 +400,7 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) - GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) + GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error @@ -1478,12 +1478,12 @@ func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId cha func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { return dp.chatv2.GetAllMembersByChatId(ctx, chatId) } -func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { +func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chatv2.GetAllMembersByPlatformId(ctx, platform, platformId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv2.GetAllMembersByPlatformIds(ctx, idByPlatform, req.Cursor, req.SortBy, req.Limit) } func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { req, err := query.DefaultPaginationHandler(opts...) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index f71b7f8e..e864ca02 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1,13 +1,15 @@ package chat_v2 import ( - "bytes" "context" + "crypto/rand" + "database/sql" "fmt" "math" "sync" "time" + "github.com/mr-tron/base58" "github.com/pkg/errors" "github.com/sirupsen/logrus" "golang.org/x/text/language" @@ -19,11 +21,13 @@ import ( chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" auth_util "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/data/twitter" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/database/query" @@ -70,45 +74,9 @@ func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier go s.asyncChatEventStreamNotifier(i, channel) } - // todo: Remove when testing is complete - s.setupMockChat() - return s } -func (s *server) setupMockChat() { - ctx := context.Background() - - chatId, _ := chat.GetChatIdFromString("c355fcec8c521e7937d45283d83bbfc63a0c688004f2386a535fc817218f917b") - chatRecord := &chat.ChatRecord{ - ChatId: chatId, - ChatType: chat.ChatTypeTwoWay, - IsVerified: true, - CreatedAt: time.Now(), - } - s.data.PutChatV2(ctx, chatRecord) - - memberId1, _ := chat.GetMemberIdFromString("034dda45-b4c2-45db-b1da-181298898a16") - memberRecord1 := &chat.MemberRecord{ - ChatId: chatId, - MemberId: memberId1, - Platform: chat.PlatformCode, - PlatformId: "8bw4gaRQk91w7vtgTN4E12GnKecY2y6CjPai7WUvWBQ8", - JoinedAt: time.Now(), - } - s.data.PutChatMemberV2(ctx, memberRecord1) - - memberId2, _ := chat.GetMemberIdFromString("a9d27058-f2d8-4034-bf52-b20c09a670de") - memberRecord2 := &chat.MemberRecord{ - ChatId: chatId, - MemberId: memberId2, - Platform: chat.PlatformCode, - PlatformId: "EDknQfoUnj73L56vKtEc6Qqw5VoHaF32eHYdz3V4y27M", - JoinedAt: time.Now(), - } - s.data.PutChatMemberV2(ctx, memberRecord2) -} - // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") @@ -154,10 +122,17 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } + myIdentities, err := s.getAllIdentities(ctx, owner) + if err != nil { + log.WithError(err).Warn("failure getting identities for owner account") + return nil, status.Error(codes.Internal, "") + } + + // todo: Use a better query that returns chat IDs. This will result in duplicate + // chat results if the user is in the chat multiple times across many identities. patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( ctx, - chat.PlatformCode, // todo: support other platforms once we support revealing identity - owner.PublicKey().ToBase58(), + myIdentities, query.WithCursor(cursor), query.WithDirection(direction), query.WithLimit(limit), @@ -181,87 +156,18 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch return nil, status.Error(codes.Internal, "") } - protoChat := &chatpb.ChatMetadata{ - ChatId: chatRecord.ChatId.ToProto(), - Kind: chatRecord.ChatType.ToProto(), - - IsMuted: platformUserMemberRecord.IsMuted, - IsSubscribed: !platformUserMemberRecord.IsUnsubscribed, - - Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))}, - } - - // Unread count calculations can be skipped for unsubscribed chats. They - // don't appear in chat history. - skipUnreadCountQuery := platformUserMemberRecord.IsUnsubscribed - - switch chatRecord.ChatType { - case chat.ChatTypeTwoWay: - protoChat.Title = "Mock Chat" // todo: proper title with localization - - protoChat.CanMute = true - protoChat.CanUnsubscribe = true - default: - return nil, status.Errorf(codes.Unimplemented, "unsupported chat type: %s", chatRecord.ChatType.String()) - } - - chatMemberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) + memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting chat members") return nil, status.Error(codes.Internal, "") } - for _, memberRecord := range chatMemberRecords { - var identity *chatpb.ChatMemberIdentity - switch memberRecord.Platform { - case chat.PlatformCode: - case chat.PlatformTwitter: - identity = &chatpb.ChatMemberIdentity{ - Platform: memberRecord.Platform.ToProto(), - Username: memberRecord.PlatformId, - } - default: - return nil, status.Errorf(codes.Unimplemented, "unsupported platform type: %s", memberRecord.Platform.String()) - } - var pointers []*chatpb.Pointer - for _, optionalPointer := range []struct { - kind chat.PointerType - value *chat.MessageId - }{ - {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, - {chat.PointerTypeRead, memberRecord.ReadPointer}, - } { - if optionalPointer.value == nil { - continue - } - - pointers = append(pointers, &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberRecord.MemberId.ToProto(), - }) - } - - protoChat.Members = append(protoChat.Members, &chatpb.ChatMember{ - MemberId: memberRecord.MemberId.ToProto(), - IsSelf: bytes.Equal(memberRecord.MemberId[:], platformUserMemberRecord.MemberId[:]), - Identity: identity, - Pointers: pointers, - }) - } - - if !skipUnreadCountQuery { - readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) - if platformUserMemberRecord.ReadPointer != nil { - readPointer = *platformUserMemberRecord.ReadPointer - } - unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, platformUserMemberRecord.MemberId, readPointer) - if err != nil { - log.WithError(err).Warn("failure getting unread count") - return nil, status.Error(codes.Internal, "") - } - protoChat.NumUnread = unreadCount + protoChat, err := s.toProtoChat(ctx, chatRecord, memberRecords, myIdentities) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") } + protoChat.Cursor = &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))} protoChats = append(protoChats, protoChat) } @@ -387,7 +293,7 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } if req.GetOpenStream() == nil { - return status.Error(codes.InvalidArgument, "open_stream is nil") + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") } owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) @@ -414,7 +320,7 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e signature := req.GetOpenStream().Signature req.GetOpenStream().Signature = nil if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { - return err + // return err } _, err = s.data.GetChatByIdV2(ctx, chatId) @@ -611,6 +517,181 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * } } +func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { + log := s.log.WithField("method", "SendMessage") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + signature := req.Signature + req.Signature = nil + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + switch typed := req.Parameters.(type) { + case *chatpb.StartChatRequest_TipChat: + intentId := base58.Encode(typed.TipChat.IntentId.Value) + log = log.WithField("intent", intentId) + + intentRecord, err := s.data.GetIntent(ctx, intentId) + if err == intent.ErrIntentNotFound { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting intent record") + return nil, status.Error(codes.Internal, "") + } + + // The intent was not for a tip. + if intentRecord.SendPrivatePaymentMetadata == nil || !intentRecord.SendPrivatePaymentMetadata.IsTip { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + tipper, err := common.NewAccountFromPublicKeyString(intentRecord.InitiatorOwnerAccount) + if err != nil { + log.WithError(err).Warn("invalid tipper owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("tipper", tipper.PublicKey().ToBase58()) + + tippee, err := common.NewAccountFromPublicKeyString(intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount) + if err != nil { + log.WithError(err).Warn("invalid tippee owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("tippee", tippee.PublicKey().ToBase58()) + + // For now, don't allow chats where you tipped yourself. + // + // todo: How do we want to handle this case? + if owner.PublicKey().ToBase58() == tipper.PublicKey().ToBase58() { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + // Only the owner of the platform user at the time of tipping can initiate the chat. + if owner.PublicKey().ToBase58() != tippee.PublicKey().ToBase58() { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_DENIED, + Chat: nil, + }, nil + } + + // todo: This will require a refactor when we allow creation of other types of chats + switch intentRecord.SendPrivatePaymentMetadata.TipMetadata.Platform { + case transactionpb.TippedUser_TWITTER: + twitterUsername := intentRecord.SendPrivatePaymentMetadata.TipMetadata.Username + + // The owner must still own the Twitter username + ownsUsername, err := s.ownsTwitterUsername(ctx, owner, twitterUsername) + if err != nil { + log.WithError(err).Warn("failure determing twitter username ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsUsername { + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_DENIED, + }, nil + } + + // todo: try to find an existing chat, but for now always create a new completely random one + var chatId chat.ChatId + rand.Read(chatId[:]) + + creationTs := time.Now() + + chatRecord := &chat.ChatRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + + IsVerified: true, + + CreatedAt: creationTs, + } + + memberRecords := []*chat.MemberRecord{ + { + ChatId: chatId, + MemberId: chat.GenerateMemberId(), + + Platform: chat.PlatformTwitter, + PlatformId: twitterUsername, + + JoinedAt: creationTs, + }, + { + ChatId: chatId, + MemberId: chat.GenerateMemberId(), + + Platform: chat.PlatformCode, + PlatformId: tipper.PublicKey().ToBase58(), + + JoinedAt: creationTs, + }, + } + + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + err := s.data.PutChatV2(ctx, chatRecord) + if err != nil { + return errors.Wrap(err, "error creating chat record") + } + + for _, memberRecord := range memberRecords { + err := s.data.PutChatMemberV2(ctx, memberRecord) + if err != nil { + return errors.Wrap(err, "error creating member record") + } + } + + return nil + }) + if err != nil { + log.WithError(err).Warn("failure creating new chat") + return nil, status.Error(codes.Internal, "") + } + + protoChat, err := s.toProtoChat( + ctx, + chatRecord, + memberRecords, + map[chat.Platform]string{ + chat.PlatformCode: owner.PublicKey().ToBase58(), + chat.PlatformTwitter: twitterUsername, + }, + ) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_OK, + Chat: protoChat, + }, nil + default: + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } + + default: + return nil, status.Error(codes.InvalidArgument, "StartChatRequest.Parameters is nil") + } +} + func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { log := s.log.WithField("method", "SendMessage") log = client.InjectLoggingMetadata(ctx, log) @@ -1059,6 +1140,105 @@ func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o return res, nil } +func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { + protoChat := &chatpb.ChatMetadata{ + ChatId: chatRecord.ChatId.ToProto(), + Kind: chatRecord.ChatType.ToProto(), + } + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + protoChat.Title = "Tip Chat" // todo: proper title with localization + + protoChat.CanMute = true + protoChat.CanUnsubscribe = true + default: + return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) + } + + for _, memberRecord := range memberRecords { + var isSelf bool + var identity *chatpb.ChatMemberIdentity + switch memberRecord.Platform { + case chat.PlatformCode: + myPublicKey, ok := myIdentitiesByPlatform[chat.PlatformCode] + isSelf = ok && myPublicKey == memberRecord.PlatformId + case chat.PlatformTwitter: + myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] + isSelf = ok && myTwitterUsername == memberRecord.PlatformId + + identity = &chatpb.ChatMemberIdentity{ + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + } + default: + return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) + } + + var pointers []*chatpb.Pointer + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + pointers = append(pointers, &chatpb.Pointer{ + Kind: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }) + } + + protoMember := &chatpb.ChatMember{ + MemberId: memberRecord.MemberId.ToProto(), + IsSelf: isSelf, + Identity: identity, + Pointers: pointers, + } + if protoMember.IsSelf { + protoMember.IsMuted = memberRecord.IsMuted + protoMember.IsSubscribed = !memberRecord.IsUnsubscribed + + if !memberRecord.IsUnsubscribed { + readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) + if memberRecord.ReadPointer != nil { + readPointer = *memberRecord.ReadPointer + } + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, memberRecord.MemberId, readPointer) + if err != nil { + return nil, errors.Wrap(err, "error calculating unread count") + } + protoMember.NumUnread = unreadCount + } + } + + protoChat.Members = append(protoChat.Members, protoMember) + } + + return protoChat, nil +} + +func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { + identities := map[chat.Platform]string{ + chat.PlatformCode: owner.PublicKey().ToBase58(), + } + + twitterUserame, ok, err := s.getOwnedTwitterUsername(ctx, owner) + if err != nil { + return nil, err + } + if ok { + identities[chat.PlatformTwitter] = twitterUserame + } + + return identities, nil +} + func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { @@ -1073,24 +1253,45 @@ func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberI case chat.PlatformCode: return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil case chat.PlatformTwitter: - // todo: This logic should live elsewhere in somewhere more common + return s.ownsTwitterUsername(ctx, owner, memberRecord.PlatformId) + default: + return false, nil + } +} - ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) - if err != nil { - return false, errors.Wrap(err, "error deriving twitter tip address") - } +// todo: This logic should live elsewhere in somewhere more common +func (s *server) ownsTwitterUsername(ctx context.Context, owner *common.Account, username string) (bool, error) { + ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) + if err != nil { + return false, errors.Wrap(err, "error deriving twitter tip address") + } - twitterRecord, err := s.data.GetTwitterUserByUsername(ctx, memberRecord.PlatformId) - switch err { - case nil: - case twitter.ErrUserNotFound: - return false, nil - default: - return false, errors.Wrap(err, "error getting twitter user") - } + twitterRecord, err := s.data.GetTwitterUserByUsername(ctx, username) + switch err { + case nil: + case twitter.ErrUserNotFound: + return false, nil + default: + return false, errors.Wrap(err, "error getting twitter user") + } - return twitterRecord.TipAddress == ownerTipAccount.PublicKey().ToBase58(), nil + return twitterRecord.TipAddress == ownerTipAccount.PublicKey().ToBase58(), nil +} + +// todo: This logic should live elsewhere in somewhere more common +func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Account) (string, bool, error) { + ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) + if err != nil { + return "", false, errors.Wrap(err, "error deriving twitter tip address") + } + + twitterRecord, err := s.data.GetTwitterUserByTipAddress(ctx, ownerTipAccount.PublicKey().ToBase58()) + switch err { + case nil: + return twitterRecord.Username, true, nil + case twitter.ErrUserNotFound: + return "", false, nil default: - return false, nil + return "", false, errors.Wrap(err, "error getting twitter user") } } From bd59f14e00b0dd30853fc263722edf98b9550e46 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 18 Jun 2024 13:32:55 -0400 Subject: [PATCH 41/71] Initial implementation of the RevealIdentity RPC --- pkg/code/data/chat/v2/memory/store.go | 27 ++++ pkg/code/data/chat/v2/model.go | 16 ++- pkg/code/data/chat/v2/store.go | 18 ++- pkg/code/data/internal.go | 4 + pkg/code/server/grpc/chat/v2/server.go | 189 +++++++++++++++++++++++-- pkg/code/server/grpc/chat/v2/stream.go | 4 - 6 files changed, 235 insertions(+), 23 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index 905872f3..ff54f9ac 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -7,6 +7,8 @@ import ( "sync" "time" + "github.com/pkg/errors" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/database/query" ) @@ -240,6 +242,31 @@ func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId c return false, nil } +// UpgradeIdentity implements chat.Store.UpgradeIdentity +func (s *store) UpgradeIdentity(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, platform chat.Platform, platformId string) error { + switch platform { + case chat.PlatformTwitter: + default: + return errors.Errorf("platform not supported for identity upgrades: %s", platform.String()) + } + + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findMemberById(chatId, memberId) + if item == nil { + return chat.ErrMemberNotFound + } + if item.Platform != chat.PlatformCode { + return chat.ErrMemberIdentityAlreadyUpgraded + } + + item.Platform = platform + item.PlatformId = platformId + + return nil +} + // SetMuteState implements chat.Store.SetMuteState func (s *store) SetMuteState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { s.mu.Lock() diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index e0cb5ed9..ef3c7071 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -192,12 +192,22 @@ func (p PointerType) String() string { } // ToProto returns the proto representation of the platform -func (p Platform) ToProto() chatpb.ChatMemberIdentity_Platform { +func GetPlatformFromProto(proto chatpb.Platform) Platform { + switch proto { + case chatpb.Platform_TWITTER: + return PlatformTwitter + default: + return PlatformUnknown + } +} + +// ToProto returns the proto representation of the platform +func (p Platform) ToProto() chatpb.Platform { switch p { case PlatformTwitter: - return chatpb.ChatMemberIdentity_TWITTER + return chatpb.Platform_TWITTER default: - return chatpb.ChatMemberIdentity_UNKNOWN + return chatpb.Platform_UNKNOWN_PLATFORM } } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 957260e8..a3fc4b43 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -8,13 +8,14 @@ import ( ) var ( - ErrChatExists = errors.New("chat already exists") - ErrChatNotFound = errors.New("chat not found") - ErrMemberExists = errors.New("chat member already exists") - ErrMemberNotFound = errors.New("chat member not found") - ErrMessageExsits = errors.New("chat message already exists") - ErrMessageNotFound = errors.New("chat message not found") - ErrInvalidPointerType = errors.New("invalid pointer type") + ErrChatExists = errors.New("chat already exists") + ErrChatNotFound = errors.New("chat not found") + ErrMemberExists = errors.New("chat member already exists") + ErrMemberNotFound = errors.New("chat member not found") + ErrMemberIdentityAlreadyUpgraded = errors.New("chat member identity already upgraded") + ErrMessageExsits = errors.New("chat message already exists") + ErrMessageNotFound = errors.New("chat message not found") + ErrInvalidPointerType = errors.New("invalid pointer type") ) // todo: Define interface methods @@ -56,6 +57,9 @@ type Store interface { // AdvancePointer advances a chat pointer for a chat member AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) + // UpgradeIdentity upgrades a chat member's identity from an anonymous state + UpgradeIdentity(ctx context.Context, chatId ChatId, memberId MemberId, platform Platform, platformId string) error + // SetMuteState updates the mute state for a chat member SetMuteState(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 11808d3e..c043a8f2 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -407,6 +407,7 @@ type DatabaseData interface { PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) + UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error @@ -1507,6 +1508,9 @@ func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) } +func (dp *DatabaseProvider) UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error { + return dp.chatv2.UpgradeIdentity(ctx, chatId, memberId, platform, platformId) +} func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { return dp.chatv2.SetMuteState(ctx, chatId, memberId, isMuted) } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index e864ca02..021608f0 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -36,6 +36,8 @@ import ( sync_util "github.com/code-payments/code-server/pkg/sync" ) +// todo: resolve some common code for sending chat messages across RPCs + const ( maxGetChatsPageSize = 100 maxGetMessagesPageSize = 100 @@ -221,7 +223,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, status.Error(codes.Internal, "") } - ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") @@ -337,7 +339,7 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e return status.Error(codes.Internal, "") } - ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return status.Error(codes.Internal, "") @@ -518,7 +520,7 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * } func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { - log := s.log.WithField("method", "SendMessage") + log := s.log.WithField("method", "StartChat") log = client.InjectLoggingMetadata(ctx, log) owner, err := common.NewAccountFromProto(req.Owner) @@ -751,7 +753,7 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest }, nil } - ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") @@ -910,7 +912,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.Internal, "") } - ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") @@ -954,6 +956,171 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR }, nil } +func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityRequest) (*chatpb.RevealIdentityResponse, error) { + log := s.log.WithField("method", "RevealIdentity") + log = client.InjectLoggingMetadata(ctx, log) + + owner, err := common.NewAccountFromProto(req.Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := chat.GetMemberIdFromProto(req.MemberId) + if err != nil { + log.WithError(err).Warn("invalid member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.Signature + req.Signature = nil + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + return nil, err + } + + platform := chat.GetPlatformFromProto(req.Identity.Platform) + + log = log.WithFields(logrus.Fields{ + "platform": platform.String(), + "username": req.Identity.Username, + }) + + _, err = s.data.GetChatByIdV2(ctx, chatId) + switch err { + case nil: + case chat.ErrChatNotFound: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_CHAT_NOT_FOUND, + }, nil + default: + log.WithError(err).Warn("failure getting chat record") + return nil, status.Error(codes.Internal, "") + } + + memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) + switch err { + case nil: + case chat.ErrMemberNotFound: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + default: + log.WithError(err).Warn("failure getting member record") + return nil, status.Error(codes.Internal, "") + } + + ownsChatMember, err := s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) + if err != nil { + log.WithError(err).Warn("failure determing chat member ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsChatMember { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + } + + switch platform { + case chat.PlatformTwitter: + ownsUsername, err := s.ownsTwitterUsername(ctx, owner, req.Identity.Username) + if err != nil { + log.WithError(err).Warn("failure determing twitter username ownership") + return nil, status.Error(codes.Internal, "") + } else if !ownsUsername { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DENIED, + }, nil + } + default: + return nil, status.Error(codes.InvalidArgument, "RevealIdentityRequest.Identity.Platform must be TWITTER") + } + + // Idempotent RPC call using the same platform and username + if memberRecord.Platform == platform && memberRecord.PlatformId == req.Identity.Username { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_OK, + }, nil + } + + // Identity was already revealed, and it isn't the specified platform and username + if memberRecord.Platform != chat.PlatformCode { + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, + }, nil + } + + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + chatMessage := &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: req.MemberId, + Content: []*chatpb.Content{ + { + Type: &chatpb.Content_IdentityRevealed{}, + }, + }, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } + + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) + switch err { + case nil: + case chat.ErrMemberIdentityAlreadyUpgraded: + return err + default: + return errors.Wrap(err, "error updating chat member identity") + } + + err := s.persistChatMessage(ctx, chatId, chatMessage) + if err != nil { + return errors.Wrap(err, "error persisting chat message") + } + return nil + }) + + if err == nil { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, + } + if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + // todo: send the push + } + + switch err { + case nil: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_OK, + }, nil + case chat.ErrMemberIdentityAlreadyUpgraded: + return &chatpb.RevealIdentityResponse{ + Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, + }, nil + default: + log.WithError(err).Warn("failure upgrading chat member identity") + return nil, status.Error(codes.Internal, "") + } +} + func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { log := s.log.WithField("method", "SetMuteState") log = client.InjectLoggingMetadata(ctx, log) @@ -998,11 +1165,11 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque return nil, status.Error(codes.Internal, "") } - isChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") - } else if !isChatMember { + } else if !ownsChatMember { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_DENIED, }, nil @@ -1063,7 +1230,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr return nil, status.Error(codes.Internal, "") } - ownsChatMember, err := s.ownsChatMember(ctx, chatId, memberId, owner) + ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") @@ -1239,7 +1406,7 @@ func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (m return identities, nil } -func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { +func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { case nil: @@ -1249,6 +1416,10 @@ func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberI return false, errors.Wrap(err, "error getting member record") } + return s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) +} + +func (s *server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatId, memberRecord *chat.MemberRecord, owner *common.Account) (bool, error) { switch memberRecord.Platform { case chat.PlatformCode: return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go index ec797a77..17d69b6a 100644 --- a/pkg/code/server/grpc/chat/v2/stream.go +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -133,10 +133,6 @@ func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan inter continue } - if strings.HasSuffix(key, typedValue.memberId.String()) { - continue - } - if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) } From bbe23e91be8dbf9e7f7db6d49b9fb6204d124dfe Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 19 Jun 2024 09:54:18 -0400 Subject: [PATCH 42/71] More refactors of chat stuff --- pkg/code/server/grpc/chat/v2/server.go | 203 ++++++++++++------------- pkg/code/server/grpc/chat/v2/stream.go | 11 +- 2 files changed, 105 insertions(+), 109 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 021608f0..09abdabc 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -321,8 +321,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e signature := req.GetOpenStream().Signature req.GetOpenStream().Signature = nil - if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { - // return err + if err := s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err } _, err = s.data.GetChatByIdV2(ctx, chatId) @@ -532,7 +532,7 @@ func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* signature := req.Signature req.Signature = nil - if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } @@ -729,7 +729,7 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest signature := req.Signature req.Signature = nil - if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } @@ -767,16 +767,7 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest chatLock.Lock() defer chatLock.Unlock() - messageId := chat.GenerateMessageId() - ts, _ := messageId.GetTimestamp() - - chatMessage := &chatpb.ChatMessage{ - MessageId: messageId.ToProto(), - SenderId: req.MemberId, - Content: req.Content, - Ts: timestamppb.New(ts), - Cursor: &chatpb.Cursor{Value: messageId[:]}, - } + chatMessage := newProtoChatMessage(memberId, req.Content...) err = s.persistChatMessage(ctx, chatId, chatMessage) if err != nil { @@ -784,16 +775,7 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, status.Error(codes.Internal, "") } - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: chatMessage, - }, - } - if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") - } - - // todo: send the push + s.onPersistChatMessage(log, chatId, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, @@ -946,7 +928,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR Pointer: req.Pointer, }, } - if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + if err := s.asyncNotifyAll(chatId, event); err != nil { log.WithError(err).Warn("failure notifying chat event") } } @@ -1061,20 +1043,17 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR chatLock.Lock() defer chatLock.Unlock() - messageId := chat.GenerateMessageId() - ts, _ := messageId.GetTimestamp() - - chatMessage := &chatpb.ChatMessage{ - MessageId: messageId.ToProto(), - SenderId: req.MemberId, - Content: []*chatpb.Content{ - { - Type: &chatpb.Content_IdentityRevealed{}, + chatMessage := newProtoChatMessage( + memberId, + &chatpb.Content{ + Type: &chatpb.Content_IdentityRevealed{ + IdentityRevealed: &chatpb.IdentityRevealedContent{ + MemberId: req.MemberId, + Identity: req.Identity, + }, }, }, - Ts: timestamppb.New(ts), - Cursor: &chatpb.Cursor{Value: messageId[:]}, - } + ) err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) @@ -1094,16 +1073,7 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR }) if err == nil { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: chatMessage, - }, - } - if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") - } - - // todo: send the push + s.onPersistChatMessage(log, chatId, chatMessage) } switch err { @@ -1251,66 +1221,11 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } -func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { - messageRecords, err := s.data.GetAllChatMessagesV2( - ctx, - chatId, - queryOptions..., - ) - if err == chat.ErrMessageNotFound { - return nil, err - } - - var userLocale *language.Tag // Loaded lazily when required - var res []*chatpb.ChatMessage - for _, messageRecord := range messageRecords { - var protoChatMessage chatpb.ChatMessage - err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) - if err != nil { - return nil, errors.Wrap(err, "error unmarshalling proto chat message") - } - - ts, err := messageRecord.GetTimestamp() - if err != nil { - return nil, errors.Wrap(err, "error getting message timestamp") - } - - for _, content := range protoChatMessage.Content { - switch typed := content.Type.(type) { - case *chatpb.Content_Localized: - if userLocale == nil { - loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) - if err != nil { - return nil, errors.Wrap(err, "error getting user locale") - } - userLocale = &loadedUserLocale - } - - typed.Localized.KeyOrText = localization.LocalizeWithFallback( - *userLocale, - localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), - typed.Localized.KeyOrText, - ) - } - } - - protoChatMessage.MessageId = messageRecord.MessageId.ToProto() - if messageRecord.Sender != nil { - protoChatMessage.SenderId = messageRecord.Sender.ToProto() - } - protoChatMessage.Ts = timestamppb.New(ts) - protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} - - res = append(res, &protoChatMessage) - } - - return res, nil -} - func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { protoChat := &chatpb.ChatMetadata{ ChatId: chatRecord.ChatId.ToProto(), Kind: chatRecord.ChatType.ToProto(), + Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, } switch chatRecord.ChatType { @@ -1390,6 +1305,75 @@ func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m return protoChat, nil } +func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { + messageRecords, err := s.data.GetAllChatMessagesV2( + ctx, + chatId, + queryOptions..., + ) + if err == chat.ErrMessageNotFound { + return nil, err + } + + var userLocale *language.Tag // Loaded lazily when required + var res []*chatpb.ChatMessage + for _, messageRecord := range messageRecords { + var protoChatMessage chatpb.ChatMessage + err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) + if err != nil { + return nil, errors.Wrap(err, "error unmarshalling proto chat message") + } + + ts, err := messageRecord.GetTimestamp() + if err != nil { + return nil, errors.Wrap(err, "error getting message timestamp") + } + + for _, content := range protoChatMessage.Content { + switch typed := content.Type.(type) { + case *chatpb.Content_Localized: + if userLocale == nil { + loadedUserLocale, err := s.data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + return nil, errors.Wrap(err, "error getting user locale") + } + userLocale = &loadedUserLocale + } + + typed.Localized.KeyOrText = localization.LocalizeWithFallback( + *userLocale, + localization.GetLocalizationKeyForUserAgent(ctx, typed.Localized.KeyOrText), + typed.Localized.KeyOrText, + ) + } + } + + protoChatMessage.MessageId = messageRecord.MessageId.ToProto() + if messageRecord.Sender != nil { + protoChatMessage.SenderId = messageRecord.Sender.ToProto() + } + protoChatMessage.Ts = timestamppb.New(ts) + protoChatMessage.Cursor = &chatpb.Cursor{Value: messageRecord.MessageId[:]} + + res = append(res, &protoChatMessage) + } + + return res, nil +} + +func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, + } + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + // todo: send the push +} + func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { identities := map[chat.Platform]string{ chat.PlatformCode: owner.PublicKey().ToBase58(), @@ -1466,3 +1450,16 @@ func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Acco return "", false, errors.Wrap(err, "error getting twitter user") } } + +func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.ChatMessage { + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + return &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: sender.ToProto(), + Content: content, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } +} diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go index 17d69b6a..3d39428d 100644 --- a/pkg/code/server/grpc/chat/v2/stream.go +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -93,15 +93,14 @@ func boundedStreamChatEventsRecv( } type chatEventNotification struct { - chatId chat.ChatId - memberId chat.MemberId - event *chatpb.ChatStreamEvent - ts time.Time + chatId chat.ChatId + event *chatpb.ChatStreamEvent + ts time.Time } -func (s *server) asyncNotifyAll(chatId chat.ChatId, memberId chat.MemberId, event *chatpb.ChatStreamEvent) error { +func (s *server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEvent) error { m := proto.Clone(event).(*chatpb.ChatStreamEvent) - ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, memberId, m, time.Now()}) + ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, m, time.Now()}) if !ok { return errors.New("chat event channel is full") } From 3f35885dac85b004da0f5d11f333af7c06eea980 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 20 Jun 2024 13:58:28 -0400 Subject: [PATCH 43/71] Incorporate small tweaks to chat APIs --- pkg/code/server/grpc/chat/v2/server.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 09abdabc..bc5351a2 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -505,7 +505,7 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * event := &chatpb.ChatStreamEvent{ Type: &chatpb.ChatStreamEvent_Pointer{ Pointer: &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), + Type: optionalPointer.kind.ToProto(), Value: optionalPointer.value.ToProto(), MemberId: memberRecord.MemberId.ToProto(), }, @@ -859,7 +859,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("member_id", memberId.String()) - pointerType := chat.GetPointerTypeFromProto(req.Pointer.Kind) + pointerType := chat.GetPointerTypeFromProto(req.Pointer.Type) log = log.WithField("pointer_type", pointerType.String()) switch pointerType { case chat.PointerTypeDelivered, chat.PointerTypeRead: @@ -1079,7 +1079,8 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR switch err { case nil: return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_OK, + Result: chatpb.RevealIdentityResponse_OK, + Message: chatMessage, }, nil case chat.ErrMemberIdentityAlreadyUpgraded: return &chatpb.RevealIdentityResponse{ @@ -1224,7 +1225,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { protoChat := &chatpb.ChatMetadata{ ChatId: chatRecord.ChatId.ToProto(), - Kind: chatRecord.ChatType.ToProto(), + Type: chatRecord.ChatType.ToProto(), Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, } @@ -1270,7 +1271,7 @@ func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m } pointers = append(pointers, &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), + Type: optionalPointer.kind.ToProto(), Value: optionalPointer.value.ToProto(), MemberId: memberRecord.MemberId.ToProto(), }) From cdef2230a1fccbe7ac60cda97992db2f452af451 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 15 May 2024 16:09:26 -0400 Subject: [PATCH 44/71] PoC in memory two way messaging --- pkg/code/server/grpc/chat/stream.go | 121 +++++++++++++++++++++++++ pkg/code/server/grpc/chat/v1/server.go | 114 ++++++++--------------- 2 files changed, 161 insertions(+), 74 deletions(-) create mode 100644 pkg/code/server/grpc/chat/stream.go diff --git a/pkg/code/server/grpc/chat/stream.go b/pkg/code/server/grpc/chat/stream.go new file mode 100644 index 00000000..9d79969b --- /dev/null +++ b/pkg/code/server/grpc/chat/stream.go @@ -0,0 +1,121 @@ +package chat + +import ( + "context" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" +) + +const ( + // todo: configurable + streamBufferSize = 64 + streamPingDelay = 5 * time.Second + streamKeepAliveRecvTimeout = 10 * time.Second + streamNotifyTimeout = 10 * time.Second +) + +type chatEventStream struct { + sync.Mutex + + closed bool + streamCh chan *chatpb.ChatStreamEvent +} + +func newChatEventStream(bufferSize int) *chatEventStream { + return &chatEventStream{ + streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), + } +} + +func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { + m := proto.Clone(event).(*chatpb.ChatStreamEvent) + + s.Lock() + + if s.closed { + s.Unlock() + return errors.New("cannot notify closed stream") + } + + select { + case s.streamCh <- m: + case <-time.After(timeout): + s.Unlock() + s.close() + return errors.New("timed out sending message to streamCh") + } + + s.Unlock() + return nil +} + +func (s *chatEventStream) close() { + s.Lock() + defer s.Unlock() + + if s.closed { + return + } + + s.closed = true + close(s.streamCh) +} + +func boundedStreamChatEventsRecv( + ctx context.Context, + streamer chatpb.Chat_StreamChatEventsServer, + timeout time.Duration, +) (req *chatpb.StreamChatEventsRequest, err error) { + done := make(chan struct{}) + go func() { + req, err = streamer.Recv() + close(done) + }() + + select { + case <-done: + return req, err + case <-ctx.Done(): + return nil, status.Error(codes.Canceled, "") + case <-time.After(timeout): + return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") + } +} + +// Very naive implementation to start +func monitorChatEventStreamHealth( + ctx context.Context, + log *logrus.Entry, + ssRef string, + streamer chatpb.Chat_StreamChatEventsServer, +) <-chan struct{} { + streamHealthChan := make(chan struct{}) + go func() { + defer close(streamHealthChan) + + for { + // todo: configurable timeout + req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + if err != nil { + return + } + + switch req.Type.(type) { + case *chatpb.StreamChatEventsRequest_Pong: + log.Tracef("received pong from client (stream=%s)", ssRef) + default: + // Client sent something unexpected. Terminate the stream + return + } + } + }() + return streamHealthChan +} diff --git a/pkg/code/server/grpc/chat/v1/server.go b/pkg/code/server/grpc/chat/v1/server.go index 5fef09f9..40a5c0f6 100644 --- a/pkg/code/server/grpc/chat/v1/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -54,18 +54,18 @@ type server struct { chatLocks *sync_util.StripedLock chatEventChans *sync_util.StripedChannel + streamsMu sync.RWMutex + streams map[string]*chatEventStream + chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { - s := &server{ - log: logrus.StandardLogger().WithField("type", "chat/v1/server"), - data: data, - auth: auth, - pusher: pusher, - streams: make(map[string]*chatEventStream), - chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters - chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { + return &server{ + log: logrus.StandardLogger().WithField("type", "chat/server"), + data: data, + auth: auth, + streams: make(map[string]*chatEventStream), } for i, channel := range s.chatEventChans.GetChannels() { @@ -403,9 +403,21 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR Pointers: []*chatpb.Pointer{req.Pointer}, } - if err := s.asyncNotifyAll(chatId, owner, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { + continue + } + + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } } + s.streamsMu.RUnlock() return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_OK, @@ -416,7 +428,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } - chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) + chatRecord, err := s.data.GetChatById(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, @@ -644,22 +656,6 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e s.streamsMu.Unlock() - defer func() { - s.streamsMu.Lock() - - log.Tracef("closing streamer (stream=%s)", streamRef) - - // We check to see if the current active stream is the one that we created. - // If it is, we can just remove it since it's closed. Otherwise, we leave it - // be, as another OpenMessageStream() call is handling it. - liveStream, exists := s.streams[streamKey] - if exists && liveStream == stream { - delete(s.streams, streamKey) - } - - s.streamsMu.Unlock() - }() - sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) @@ -733,15 +729,11 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } switch req.Content[0].Type.(type) { - case *chatpb.Content_Text, *chatpb.Content_ThankYou: + case *chatpb.Content_UserText: default: - return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") + return nil, status.Error(codes.InvalidArgument, "content[0] must be UserText") } - chatLock := s.chatLocks.Get(chatId[:]) - chatLock.Lock() - defer chatLock.Unlock() - // todo: Revisit message IDs messageId, err := common.NewRandomAccount() if err != nil { @@ -757,54 +749,28 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest Cursor: nil, // todo: Don't have cursor until we save it to the DB } - // todo: Save the message to the DB - event := &chatpb.ChatStreamEvent{ Messages: []*chatpb.ChatMessage{chatMessage}, } - if err := s.asyncNotifyAll(chatId, owner, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") - } + s.streamsMu.RLock() + for key, stream := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { + continue + } - s.asyncPushChatMessage(owner, chatId, chatMessage) + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, Message: chatMessage, }, nil } - -// todo: doesn't respect mute/unsubscribe rules -// todo: only sends pushes to active stream listeners instead of all message recipients -func (s *server) asyncPushChatMessage(sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { - ctx := context.TODO() - - go func() { - s.streamsMu.RLock() - for key := range s.streams { - if !strings.HasPrefix(key, chatId.String()) { - continue - } - - receiver, err := common.NewAccountFromPublicKeyString(strings.Split(key, ":")[1]) - if err != nil { - continue - } - - if bytes.Equal(sender.PublicKey().ToBytes(), receiver.PublicKey().ToBytes()) { - continue - } - - go push_util.SendChatMessagePushNotification( - ctx, - s.data, - s.pusher, - "TontonTwitch", - receiver, - chatMessage, - ) - } - s.streamsMu.RUnlock() - }() -} From e0766f4f3a6b8c93769110cc2b296fae8ab3506a Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 21 May 2024 14:11:55 -0400 Subject: [PATCH 45/71] Add support for thank you messages --- pkg/code/server/grpc/chat/v1/server.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/code/server/grpc/chat/v1/server.go b/pkg/code/server/grpc/chat/v1/server.go index 40a5c0f6..49859416 100644 --- a/pkg/code/server/grpc/chat/v1/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -729,9 +729,9 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } switch req.Content[0].Type.(type) { - case *chatpb.Content_UserText: + case *chatpb.Content_Text, *chatpb.Content_ThankYou: default: - return nil, status.Error(codes.InvalidArgument, "content[0] must be UserText") + return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") } // todo: Revisit message IDs From c2aff7b4f54bea3238c708e084f4ffbab6d4308b Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 11 Jun 2024 11:14:13 -0400 Subject: [PATCH 46/71] Add missing result codes and update/comment on flush --- go.mod | 4 ++++ go.sum | 5 +++++ pkg/code/server/grpc/chat/v2/server.go | 6 +++--- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 4e808287..54ecf00d 100644 --- a/go.mod +++ b/go.mod @@ -6,8 +6,12 @@ require ( firebase.google.com/go/v4 v4.8.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 +<<<<<<< HEAD github.com/code-payments/code-protobuf-api v1.19.0 github.com/dghubble/oauth1 v0.7.3 +======= + github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73 +>>>>>>> ae8bddd (Add missing result codes and update/comment on flush) github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/golang-jwt/jwt/v5 v5.0.0 diff --git a/go.sum b/go.sum index 7414b428..98dad0b3 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,13 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +<<<<<<< HEAD github.com/code-payments/code-protobuf-api v1.19.0 h1:md/eJhqltz8dDY0U8hwT/42C3h+kP+W/68D7RMSjqPo= github.com/code-payments/code-protobuf-api v1.19.0/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +======= +github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73 h1:gdj/RvbLkcfxeWsrHJSu6Z8rkNtWvrIMZz/1WQlxVyg= +github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +>>>>>>> ae8bddd (Add missing result codes and update/comment on flush) github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index bc5351a2..3c6180b8 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -216,7 +216,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest case nil: case chat.ErrChatNotFound: return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_CHAT_NOT_FOUND, + Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, }, nil default: log.WithError(err).Warn("failure getting chat record") @@ -392,8 +392,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) + // todo: We should also "flush" pointers for each chat member go s.flushMessages(ctx, chatId, owner, stream) - go s.flushPointers(ctx, chatId, stream) for { select { @@ -1140,7 +1140,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { + } else if !isChatMember { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_DENIED, }, nil From 202e51c1457a332416f2a0d68f46960e88eba874 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 13 Jun 2024 10:40:30 -0400 Subject: [PATCH 47/71] Implement GetChats RPC limited to anonymous chat membership --- go.mod | 4 - go.sum | 5 - pkg/code/data/chat/v2/memory/store.go | 28 ++---- pkg/code/data/chat/v2/model.go | 54 ++++++++--- pkg/code/data/chat/v2/store.go | 9 +- pkg/code/data/internal.go | 12 +-- pkg/code/server/grpc/chat/v2/server.go | 126 ++++++++++++++++++++++--- 7 files changed, 168 insertions(+), 70 deletions(-) diff --git a/go.mod b/go.mod index 54ecf00d..4e808287 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,8 @@ require ( firebase.google.com/go/v4 v4.8.0 github.com/aws/aws-sdk-go-v2 v0.17.0 github.com/bits-and-blooms/bloom/v3 v3.1.0 -<<<<<<< HEAD github.com/code-payments/code-protobuf-api v1.19.0 github.com/dghubble/oauth1 v0.7.3 -======= - github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73 ->>>>>>> ae8bddd (Add missing result codes and update/comment on flush) github.com/emirpasic/gods v1.12.0 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/golang-jwt/jwt/v5 v5.0.0 diff --git a/go.sum b/go.sum index 98dad0b3..7414b428 100644 --- a/go.sum +++ b/go.sum @@ -121,13 +121,8 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -<<<<<<< HEAD github.com/code-payments/code-protobuf-api v1.19.0 h1:md/eJhqltz8dDY0U8hwT/42C3h+kP+W/68D7RMSjqPo= github.com/code-payments/code-protobuf-api v1.19.0/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= -======= -github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73 h1:gdj/RvbLkcfxeWsrHJSu6Z8rkNtWvrIMZz/1WQlxVyg= -github.com/code-payments/code-protobuf-api v1.16.7-0.20240611151313-ca7587f92a73/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= ->>>>>>> ae8bddd (Add missing result codes and update/comment on flush) github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index ff54f9ac..0967d5be 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -82,12 +82,12 @@ func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([] return cloneMemberRecords(items), nil } -// GetAllMembersByPlatformIds implements chat.store.GetAllMembersByPlatformIds -func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[chat.Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { +// GetAllMembersByPlatformId implements chat.store.GetAllMembersByPlatformId +func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { s.mu.Lock() defer s.mu.Unlock() - items := s.findMembersByPlatformIds(idByPlatform) + items := s.findMembersByPlatformId(platform, platformId) items, err := s.getMemberRecordPage(items, cursor, direction, limit) if err != nil { return nil, err @@ -100,13 +100,12 @@ func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[c } // GetUnreadCount implements chat.store.GetUnreadCount -func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) { +func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, readPointer chat.MessageId) (uint32, error) { s.mu.Lock() defer s.mu.Unlock() items := s.findMessagesByChatId(chatId) items = s.filterMessagesAfter(items, readPointer) - items = s.filterMessagesNotSentBy(items, memberId) items = s.filterNotifiedMessages(items) return uint32(len(items)), nil } @@ -351,15 +350,10 @@ func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { return res } -func (s *store) findMembersByPlatformIds(idByPlatform map[chat.Platform]string) []*chat.MemberRecord { +func (s *store) findMembersByPlatformId(platform chat.Platform, platformId string) []*chat.MemberRecord { var res []*chat.MemberRecord for _, item := range s.memberRecords { - platformId, ok := idByPlatform[item.Platform] - if !ok { - continue - } - - if platformId == item.PlatformId { + if platform == item.Platform && platformId == item.PlatformId { res = append(res, item) } } @@ -447,16 +441,6 @@ func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.Me return res } -func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord { - var res []*chat.MessageRecord - for _, item := range items { - if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) { - res = append(res, item) - } - } - return res -} - func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { var res []*chat.MessageRecord for _, item := range items { diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index ef3c7071..b1132706 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -114,7 +114,43 @@ func (a MessagesByMessageId) Less(i, j int) bool { } // GetChatTypeFromProto gets a chat type from the protobuf variant -func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { +func GetChatTypeFromProto(proto chatpb.ChatMetadata_Kind) ChatType { + switch proto { + case chatpb.ChatMetadata_NOTIFICATION: + return ChatTypeNotification + case chatpb.ChatMetadata_TWO_WAY: + return ChatTypeTwoWay + default: + return ChatTypeUnknown + } +} + +// ToProto returns the proto representation of the chat type +func (c ChatType) ToProto() chatpb.ChatMetadata_Kind { + switch c { + case ChatTypeNotification: + return chatpb.ChatMetadata_NOTIFICATION + case ChatTypeTwoWay: + return chatpb.ChatMetadata_TWO_WAY + default: + return chatpb.ChatMetadata_UNKNOWN + } +} + +// String returns the string representation of the chat type +func (c ChatType) String() string { + switch c { + case ChatTypeNotification: + return "notification" + case ChatTypeTwoWay: + return "two-way" + default: + return "unknown" + } +} + +// GetPointerTypeFromProto gets a chat ID from the protobuf variant +func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { switch proto { case chatpb.ChatType_NOTIFICATION: return ChatTypeNotification @@ -192,22 +228,12 @@ func (p PointerType) String() string { } // ToProto returns the proto representation of the platform -func GetPlatformFromProto(proto chatpb.Platform) Platform { - switch proto { - case chatpb.Platform_TWITTER: - return PlatformTwitter - default: - return PlatformUnknown - } -} - -// ToProto returns the proto representation of the platform -func (p Platform) ToProto() chatpb.Platform { +func (p Platform) ToProto() chatpb.ChatMemberIdentity_Platform { switch p { case PlatformTwitter: - return chatpb.Platform_TWITTER + return chatpb.ChatMemberIdentity_TWITTER default: - return chatpb.Platform_UNKNOWN_PLATFORM + return chatpb.ChatMemberIdentity_UNKNOWN } } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index a3fc4b43..f3a4ffda 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -34,16 +34,17 @@ type Store interface { // todo: Add paging when we introduce group chats GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) - // GetAllMembersByPlatformIds gets all members for platform users across all chats - GetAllMembersByPlatformIds(ctx context.Context, idByPlatform map[Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + // GetAllMembersByPlatformId gets all members for a given platform user across + // all chats + GetAllMembersByPlatformId(ctx context.Context, platform Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) // GetAllMessagesByChatId gets all messages for a given chat // // Note: Cursor is a message ID GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) - // GetUnreadCount gets the unread message count for a chat ID at a read pointer for a given chat member - GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error) + // GetUnreadCount gets the unread message count for a chat ID at a read pointer + GetUnreadCount(ctx context.Context, chatId ChatId, readPointer MessageId) (uint32, error) // PutChat creates a new chat PutChat(ctx context.Context, record *ChatRecord) error diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index c043a8f2..d0cae22f 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -400,9 +400,9 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) - GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) + GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) - GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error @@ -1479,12 +1479,12 @@ func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId cha func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { return dp.chatv2.GetAllMembersByChatId(ctx, chatId) } -func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { +func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chatv2.GetAllMembersByPlatformIds(ctx, idByPlatform, req.Cursor, req.SortBy, req.Limit) + return dp.chatv2.GetAllMembersByPlatformId(ctx, platform, platformId, req.Cursor, req.SortBy, req.Limit) } func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { req, err := query.DefaultPaginationHandler(opts...) @@ -1493,8 +1493,8 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) { - return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer) +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetUnreadCount(ctx, chatId, readPointer) } func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { return dp.chatv2.PutChat(ctx, record) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 3c6180b8..e19f0bc4 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1,6 +1,7 @@ package chat_v2 import ( + "bytes" "context" "crypto/rand" "database/sql" @@ -79,6 +80,39 @@ func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier return s } +func (s *server) setupMockChat() { + ctx := context.Background() + + chatId, _ := chat.GetChatIdFromString("c355fcec8c521e7937d45283d83bbfc63a0c688004f2386a535fc817218f917b") + chatRecord := &chat.ChatRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + IsVerified: true, + CreatedAt: time.Now(), + } + s.data.PutChatV2(ctx, chatRecord) + + memberId1, _ := chat.GetMemberIdFromString("034dda45-b4c2-45db-b1da-181298898a16") + memberRecord1 := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId1, + Platform: chat.PlatformCode, + PlatformId: "8bw4gaRQk91w7vtgTN4E12GnKecY2y6CjPai7WUvWBQ8", + JoinedAt: time.Now(), + } + s.data.PutChatMemberV2(ctx, memberRecord1) + + memberId2, _ := chat.GetMemberIdFromString("a9d27058-f2d8-4034-bf52-b20c09a670de") + memberRecord2 := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId2, + Platform: chat.PlatformCode, + PlatformId: "EDknQfoUnj73L56vKtEc6Qqw5VoHaF32eHYdz3V4y27M", + JoinedAt: time.Now(), + } + s.data.PutChatMemberV2(ctx, memberRecord2) +} + // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") @@ -124,17 +158,10 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } - myIdentities, err := s.getAllIdentities(ctx, owner) - if err != nil { - log.WithError(err).Warn("failure getting identities for owner account") - return nil, status.Error(codes.Internal, "") - } - - // todo: Use a better query that returns chat IDs. This will result in duplicate - // chat results if the user is in the chat multiple times across many identities. patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( ctx, - myIdentities, + chat.PlatformCode, // todo: support other platforms once we support revealing identity + owner.PublicKey().ToBase58(), query.WithCursor(cursor), query.WithDirection(direction), query.WithLimit(limit), @@ -158,18 +185,87 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch return nil, status.Error(codes.Internal, "") } - memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) + protoChat := &chatpb.ChatMetadata{ + ChatId: chatRecord.ChatId.ToProto(), + Kind: chatRecord.ChatType.ToProto(), + + IsMuted: platformUserMemberRecord.IsMuted, + IsSubscribed: !platformUserMemberRecord.IsUnsubscribed, + + Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))}, + } + + // Unread count calculations can be skipped for unsubscribed chats. They + // don't appear in chat history. + skipUnreadCountQuery := platformUserMemberRecord.IsUnsubscribed + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + protoChat.Title = "Mock Chat" // todo: proper title with localization + + protoChat.CanMute = true + protoChat.CanUnsubscribe = true + default: + return nil, status.Errorf(codes.Unimplemented, "unsupported chat type: %s", chatRecord.ChatType.String()) + } + + chatMemberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting chat members") return nil, status.Error(codes.Internal, "") } + for _, memberRecord := range chatMemberRecords { + var identity *chatpb.ChatMemberIdentity + switch memberRecord.Platform { + case chat.PlatformCode: + case chat.PlatformTwitter: + identity = &chatpb.ChatMemberIdentity{ + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + } + default: + return nil, status.Errorf(codes.Unimplemented, "unsupported platform type: %s", memberRecord.Platform.String()) + } - protoChat, err := s.toProtoChat(ctx, chatRecord, memberRecords, myIdentities) - if err != nil { - log.WithError(err).Warn("failure constructing proto chat message") - return nil, status.Error(codes.Internal, "") + var pointers []*chatpb.Pointer + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + pointers = append(pointers, &chatpb.Pointer{ + Kind: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }) + } + + protoChat.Members = append(protoChat.Members, &chatpb.ChatMember{ + MemberId: memberRecord.MemberId.ToProto(), + IsSelf: bytes.Equal(memberRecord.MemberId[:], platformUserMemberRecord.MemberId[:]), + Identity: identity, + Pointers: pointers, + }) + } + + if !skipUnreadCountQuery { + readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) + if platformUserMemberRecord.ReadPointer != nil { + readPointer = *platformUserMemberRecord.ReadPointer + } + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, readPointer) + if err != nil { + log.WithError(err).Warn("failure getting unread count") + return nil, status.Error(codes.Internal, "") + } + protoChat.NumUnread = unreadCount } - protoChat.Cursor = &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))} protoChats = append(protoChats, protoChat) } From e49e233eacca29b49f3a14816894cf3f60d3901a Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Mon, 17 Jun 2024 13:23:58 -0400 Subject: [PATCH 48/71] Fix build with refactor changes to chat protos --- pkg/code/data/chat/v2/model.go | 38 +--------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index b1132706..e0cb5ed9 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -114,43 +114,7 @@ func (a MessagesByMessageId) Less(i, j int) bool { } // GetChatTypeFromProto gets a chat type from the protobuf variant -func GetChatTypeFromProto(proto chatpb.ChatMetadata_Kind) ChatType { - switch proto { - case chatpb.ChatMetadata_NOTIFICATION: - return ChatTypeNotification - case chatpb.ChatMetadata_TWO_WAY: - return ChatTypeTwoWay - default: - return ChatTypeUnknown - } -} - -// ToProto returns the proto representation of the chat type -func (c ChatType) ToProto() chatpb.ChatMetadata_Kind { - switch c { - case ChatTypeNotification: - return chatpb.ChatMetadata_NOTIFICATION - case ChatTypeTwoWay: - return chatpb.ChatMetadata_TWO_WAY - default: - return chatpb.ChatMetadata_UNKNOWN - } -} - -// String returns the string representation of the chat type -func (c ChatType) String() string { - switch c { - case ChatTypeNotification: - return "notification" - case ChatTypeTwoWay: - return "two-way" - default: - return "unknown" - } -} - -// GetPointerTypeFromProto gets a chat ID from the protobuf variant -func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType { +func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { switch proto { case chatpb.ChatType_NOTIFICATION: return ChatTypeNotification From 9f25c48e64831035fbfb7d2f601a54afb07b4d37 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 18 Jun 2024 11:55:46 -0400 Subject: [PATCH 49/71] Initial implementation of StartChat that always starts a new chat --- pkg/code/data/chat/v2/memory/store.go | 15 +- pkg/code/data/chat/v2/store.go | 5 +- pkg/code/data/internal.go | 6 +- pkg/code/server/grpc/chat/v2/server.go | 235 +++++++++++-------------- 4 files changed, 113 insertions(+), 148 deletions(-) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index 0967d5be..82ac8755 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -82,12 +82,12 @@ func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([] return cloneMemberRecords(items), nil } -// GetAllMembersByPlatformId implements chat.store.GetAllMembersByPlatformId -func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { +// GetAllMembersByPlatformIds implements chat.store.GetAllMembersByPlatformIds +func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[chat.Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { s.mu.Lock() defer s.mu.Unlock() - items := s.findMembersByPlatformId(platform, platformId) + items := s.findMembersByPlatformIds(idByPlatform) items, err := s.getMemberRecordPage(items, cursor, direction, limit) if err != nil { return nil, err @@ -350,10 +350,15 @@ func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { return res } -func (s *store) findMembersByPlatformId(platform chat.Platform, platformId string) []*chat.MemberRecord { +func (s *store) findMembersByPlatformIds(idByPlatform map[chat.Platform]string) []*chat.MemberRecord { var res []*chat.MemberRecord for _, item := range s.memberRecords { - if platform == item.Platform && platformId == item.PlatformId { + platformId, ok := idByPlatform[item.Platform] + if !ok { + continue + } + + if platformId == item.PlatformId { res = append(res, item) } } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index f3a4ffda..4aa3e1b7 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -34,9 +34,8 @@ type Store interface { // todo: Add paging when we introduce group chats GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) - // GetAllMembersByPlatformId gets all members for a given platform user across - // all chats - GetAllMembersByPlatformId(ctx context.Context, platform Platform, platformId string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + // GetAllMembersByPlatformIds gets all members for platform users across all chats + GetAllMembersByPlatformIds(ctx context.Context, idByPlatform map[Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) // GetAllMessagesByChatId gets all messages for a given chat // diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index d0cae22f..a0fcc965 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -400,7 +400,7 @@ type DatabaseData interface { GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) - GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) + GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error @@ -1479,12 +1479,12 @@ func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId cha func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { return dp.chatv2.GetAllMembersByChatId(ctx, chatId) } -func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { +func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { req, err := query.DefaultPaginationHandler(opts...) if err != nil { return nil, err } - return dp.chatv2.GetAllMembersByPlatformId(ctx, platform, platformId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv2.GetAllMembersByPlatformIds(ctx, idByPlatform, req.Cursor, req.SortBy, req.Limit) } func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { req, err := query.DefaultPaginationHandler(opts...) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index e19f0bc4..921d6b05 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1,7 +1,6 @@ package chat_v2 import ( - "bytes" "context" "crypto/rand" "database/sql" @@ -80,39 +79,6 @@ func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier return s } -func (s *server) setupMockChat() { - ctx := context.Background() - - chatId, _ := chat.GetChatIdFromString("c355fcec8c521e7937d45283d83bbfc63a0c688004f2386a535fc817218f917b") - chatRecord := &chat.ChatRecord{ - ChatId: chatId, - ChatType: chat.ChatTypeTwoWay, - IsVerified: true, - CreatedAt: time.Now(), - } - s.data.PutChatV2(ctx, chatRecord) - - memberId1, _ := chat.GetMemberIdFromString("034dda45-b4c2-45db-b1da-181298898a16") - memberRecord1 := &chat.MemberRecord{ - ChatId: chatId, - MemberId: memberId1, - Platform: chat.PlatformCode, - PlatformId: "8bw4gaRQk91w7vtgTN4E12GnKecY2y6CjPai7WUvWBQ8", - JoinedAt: time.Now(), - } - s.data.PutChatMemberV2(ctx, memberRecord1) - - memberId2, _ := chat.GetMemberIdFromString("a9d27058-f2d8-4034-bf52-b20c09a670de") - memberRecord2 := &chat.MemberRecord{ - ChatId: chatId, - MemberId: memberId2, - Platform: chat.PlatformCode, - PlatformId: "EDknQfoUnj73L56vKtEc6Qqw5VoHaF32eHYdz3V4y27M", - JoinedAt: time.Now(), - } - s.data.PutChatMemberV2(ctx, memberRecord2) -} - // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") @@ -158,10 +124,17 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch } } + myIdentities, err := s.getAllIdentities(ctx, owner) + if err != nil { + log.WithError(err).Warn("failure getting identities for owner account") + return nil, status.Error(codes.Internal, "") + } + + // todo: Use a better query that returns chat IDs. This will result in duplicate + // chat results if the user is in the chat multiple times across many identities. patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( ctx, - chat.PlatformCode, // todo: support other platforms once we support revealing identity - owner.PublicKey().ToBase58(), + myIdentities, query.WithCursor(cursor), query.WithDirection(direction), query.WithLimit(limit), @@ -185,87 +158,18 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch return nil, status.Error(codes.Internal, "") } - protoChat := &chatpb.ChatMetadata{ - ChatId: chatRecord.ChatId.ToProto(), - Kind: chatRecord.ChatType.ToProto(), - - IsMuted: platformUserMemberRecord.IsMuted, - IsSubscribed: !platformUserMemberRecord.IsUnsubscribed, - - Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))}, - } - - // Unread count calculations can be skipped for unsubscribed chats. They - // don't appear in chat history. - skipUnreadCountQuery := platformUserMemberRecord.IsUnsubscribed - - switch chatRecord.ChatType { - case chat.ChatTypeTwoWay: - protoChat.Title = "Mock Chat" // todo: proper title with localization - - protoChat.CanMute = true - protoChat.CanUnsubscribe = true - default: - return nil, status.Errorf(codes.Unimplemented, "unsupported chat type: %s", chatRecord.ChatType.String()) - } - - chatMemberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) + memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) if err != nil { log.WithError(err).Warn("failure getting chat members") return nil, status.Error(codes.Internal, "") } - for _, memberRecord := range chatMemberRecords { - var identity *chatpb.ChatMemberIdentity - switch memberRecord.Platform { - case chat.PlatformCode: - case chat.PlatformTwitter: - identity = &chatpb.ChatMemberIdentity{ - Platform: memberRecord.Platform.ToProto(), - Username: memberRecord.PlatformId, - } - default: - return nil, status.Errorf(codes.Unimplemented, "unsupported platform type: %s", memberRecord.Platform.String()) - } - - var pointers []*chatpb.Pointer - for _, optionalPointer := range []struct { - kind chat.PointerType - value *chat.MessageId - }{ - {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, - {chat.PointerTypeRead, memberRecord.ReadPointer}, - } { - if optionalPointer.value == nil { - continue - } - - pointers = append(pointers, &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberRecord.MemberId.ToProto(), - }) - } - - protoChat.Members = append(protoChat.Members, &chatpb.ChatMember{ - MemberId: memberRecord.MemberId.ToProto(), - IsSelf: bytes.Equal(memberRecord.MemberId[:], platformUserMemberRecord.MemberId[:]), - Identity: identity, - Pointers: pointers, - }) - } - if !skipUnreadCountQuery { - readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) - if platformUserMemberRecord.ReadPointer != nil { - readPointer = *platformUserMemberRecord.ReadPointer - } - unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, readPointer) - if err != nil { - log.WithError(err).Warn("failure getting unread count") - return nil, status.Error(codes.Internal, "") - } - protoChat.NumUnread = unreadCount + protoChat, err := s.toProtoChat(ctx, chatRecord, memberRecords, myIdentities) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") } + protoChat.Cursor = &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))} protoChats = append(protoChats, protoChat) } @@ -417,8 +321,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e signature := req.GetOpenStream().Signature req.GetOpenStream().Signature = nil - if err := s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { - return err + if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + // return err } _, err = s.data.GetChatByIdV2(ctx, chatId) @@ -616,7 +520,7 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * } func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { - log := s.log.WithField("method", "StartChat") + log := s.log.WithField("method", "SendMessage") log = client.InjectLoggingMetadata(ctx, log) owner, err := common.NewAccountFromProto(req.Owner) @@ -628,7 +532,7 @@ func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* signature := req.Signature req.Signature = nil - if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { + if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } @@ -1458,17 +1362,87 @@ func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o return res, nil } -func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: chatMessage, - }, +func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { + protoChat := &chatpb.ChatMetadata{ + ChatId: chatRecord.ChatId.ToProto(), + Kind: chatRecord.ChatType.ToProto(), } - if err := s.asyncNotifyAll(chatId, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") + + switch chatRecord.ChatType { + case chat.ChatTypeTwoWay: + protoChat.Title = "Tip Chat" // todo: proper title with localization + + protoChat.CanMute = true + protoChat.CanUnsubscribe = true + default: + return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) + } + + for _, memberRecord := range memberRecords { + var isSelf bool + var identity *chatpb.ChatMemberIdentity + switch memberRecord.Platform { + case chat.PlatformCode: + myPublicKey, ok := myIdentitiesByPlatform[chat.PlatformCode] + isSelf = ok && myPublicKey == memberRecord.PlatformId + case chat.PlatformTwitter: + myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] + isSelf = ok && myTwitterUsername == memberRecord.PlatformId + + identity = &chatpb.ChatMemberIdentity{ + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + } + default: + return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) + } + + var pointers []*chatpb.Pointer + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + pointers = append(pointers, &chatpb.Pointer{ + Kind: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberRecord.MemberId.ToProto(), + }) + } + + protoMember := &chatpb.ChatMember{ + MemberId: memberRecord.MemberId.ToProto(), + IsSelf: isSelf, + Identity: identity, + Pointers: pointers, + } + if protoMember.IsSelf { + protoMember.IsMuted = memberRecord.IsMuted + protoMember.IsSubscribed = !memberRecord.IsUnsubscribed + + if !memberRecord.IsUnsubscribed { + readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) + if memberRecord.ReadPointer != nil { + readPointer = *memberRecord.ReadPointer + } + unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, memberRecord.MemberId, readPointer) + if err != nil { + return nil, errors.Wrap(err, "error calculating unread count") + } + protoMember.NumUnread = unreadCount + } + } + + protoChat.Members = append(protoChat.Members, protoMember) } - // todo: send the push + return protoChat, nil } func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { @@ -1487,7 +1461,7 @@ func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (m return identities, nil } -func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { +func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { case nil: @@ -1547,16 +1521,3 @@ func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Acco return "", false, errors.Wrap(err, "error getting twitter user") } } - -func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.ChatMessage { - messageId := chat.GenerateMessageId() - ts, _ := messageId.GetTimestamp() - - return &chatpb.ChatMessage{ - MessageId: messageId.ToProto(), - SenderId: sender.ToProto(), - Content: content, - Ts: timestamppb.New(ts), - Cursor: &chatpb.Cursor{Value: messageId[:]}, - } -} From 55c9cf7cc0cd8c5a44bb9ae55efa64deb81b35ea Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Tue, 18 Jun 2024 13:32:55 -0400 Subject: [PATCH 50/71] Initial implementation of the RevealIdentity RPC --- pkg/code/data/chat/v2/model.go | 16 ++++++++-- pkg/code/server/grpc/chat/v2/server.go | 41 ++++++++++++++++---------- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index e0cb5ed9..ef3c7071 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -192,12 +192,22 @@ func (p PointerType) String() string { } // ToProto returns the proto representation of the platform -func (p Platform) ToProto() chatpb.ChatMemberIdentity_Platform { +func GetPlatformFromProto(proto chatpb.Platform) Platform { + switch proto { + case chatpb.Platform_TWITTER: + return PlatformTwitter + default: + return PlatformUnknown + } +} + +// ToProto returns the proto representation of the platform +func (p Platform) ToProto() chatpb.Platform { switch p { case PlatformTwitter: - return chatpb.ChatMemberIdentity_TWITTER + return chatpb.Platform_TWITTER default: - return chatpb.ChatMemberIdentity_UNKNOWN + return chatpb.Platform_UNKNOWN_PLATFORM } } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 921d6b05..b57a2b65 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -520,7 +520,7 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * } func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { - log := s.log.WithField("method", "SendMessage") + log := s.log.WithField("method", "StartChat") log = client.InjectLoggingMetadata(ctx, log) owner, err := common.NewAccountFromProto(req.Owner) @@ -1043,17 +1043,20 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR chatLock.Lock() defer chatLock.Unlock() - chatMessage := newProtoChatMessage( - memberId, - &chatpb.Content{ - Type: &chatpb.Content_IdentityRevealed{ - IdentityRevealed: &chatpb.IdentityRevealedContent{ - MemberId: req.MemberId, - Identity: req.Identity, - }, + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + chatMessage := &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: req.MemberId, + Content: []*chatpb.Content{ + { + Type: &chatpb.Content_IdentityRevealed{}, }, }, - ) + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) @@ -1073,14 +1076,22 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR }) if err == nil { - s.onPersistChatMessage(log, chatId, chatMessage) + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, + } + if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") + } + + // todo: send the push } switch err { case nil: return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_OK, - Message: chatMessage, + Result: chatpb.RevealIdentityResponse_OK, }, nil case chat.ErrMemberIdentityAlreadyUpgraded: return &chatpb.RevealIdentityResponse{ @@ -1140,7 +1151,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque if err != nil { log.WithError(err).Warn("failure determing chat member ownership") return nil, status.Error(codes.Internal, "") - } else if !isChatMember { + } else if !ownsChatMember { return &chatpb.SetMuteStateResponse{ Result: chatpb.SetMuteStateResponse_DENIED, }, nil @@ -1461,7 +1472,7 @@ func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (m return identities, nil } -func (s *server) ownsChatMember(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { +func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { case nil: From 516a5021e8ebd666c7a05f5a8b155cdf0ab0cd1e Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Wed, 19 Jun 2024 09:54:18 -0400 Subject: [PATCH 51/71] More refactors of chat stuff --- pkg/code/server/grpc/chat/v2/server.go | 141 +++++++------------------ 1 file changed, 36 insertions(+), 105 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index b57a2b65..12757839 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -321,8 +321,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e signature := req.GetOpenStream().Signature req.GetOpenStream().Signature = nil - if err = s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { - // return err + if err := s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { + return err } _, err = s.data.GetChatByIdV2(ctx, chatId) @@ -532,7 +532,7 @@ func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* signature := req.Signature req.Signature = nil - if err = s.auth.Authenticate(ctx, owner, req, signature); err != nil { + if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } @@ -1043,20 +1043,17 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR chatLock.Lock() defer chatLock.Unlock() - messageId := chat.GenerateMessageId() - ts, _ := messageId.GetTimestamp() - - chatMessage := &chatpb.ChatMessage{ - MessageId: messageId.ToProto(), - SenderId: req.MemberId, - Content: []*chatpb.Content{ - { - Type: &chatpb.Content_IdentityRevealed{}, + chatMessage := newProtoChatMessage( + memberId, + &chatpb.Content{ + Type: &chatpb.Content_IdentityRevealed{ + IdentityRevealed: &chatpb.IdentityRevealedContent{ + MemberId: req.MemberId, + Identity: req.Identity, + }, }, }, - Ts: timestamppb.New(ts), - Cursor: &chatpb.Cursor{Value: messageId[:]}, - } + ) err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) @@ -1076,16 +1073,7 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR }) if err == nil { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: chatMessage, - }, - } - if err := s.asyncNotifyAll(chatId, memberId, event); err != nil { - log.WithError(err).Warn("failure notifying chat event") - } - - // todo: send the push + s.onPersistChatMessage(log, chatId, chatMessage) } switch err { @@ -1236,7 +1224,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { protoChat := &chatpb.ChatMetadata{ ChatId: chatRecord.ChatId.ToProto(), - Type: chatRecord.ChatType.ToProto(), + Kind: chatRecord.ChatType.ToProto(), Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, } @@ -1282,7 +1270,7 @@ func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m } pointers = append(pointers, &chatpb.Pointer{ - Type: optionalPointer.kind.ToProto(), + Kind: optionalPointer.kind.ToProto(), Value: optionalPointer.value.ToProto(), MemberId: memberRecord.MemberId.ToProto(), }) @@ -1373,87 +1361,17 @@ func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o return res, nil } -func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { - protoChat := &chatpb.ChatMetadata{ - ChatId: chatRecord.ChatId.ToProto(), - Kind: chatRecord.ChatType.ToProto(), - } - - switch chatRecord.ChatType { - case chat.ChatTypeTwoWay: - protoChat.Title = "Tip Chat" // todo: proper title with localization - - protoChat.CanMute = true - protoChat.CanUnsubscribe = true - default: - return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) +func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: chatMessage, + }, } - - for _, memberRecord := range memberRecords { - var isSelf bool - var identity *chatpb.ChatMemberIdentity - switch memberRecord.Platform { - case chat.PlatformCode: - myPublicKey, ok := myIdentitiesByPlatform[chat.PlatformCode] - isSelf = ok && myPublicKey == memberRecord.PlatformId - case chat.PlatformTwitter: - myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] - isSelf = ok && myTwitterUsername == memberRecord.PlatformId - - identity = &chatpb.ChatMemberIdentity{ - Platform: memberRecord.Platform.ToProto(), - Username: memberRecord.PlatformId, - } - default: - return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) - } - - var pointers []*chatpb.Pointer - for _, optionalPointer := range []struct { - kind chat.PointerType - value *chat.MessageId - }{ - {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, - {chat.PointerTypeRead, memberRecord.ReadPointer}, - } { - if optionalPointer.value == nil { - continue - } - - pointers = append(pointers, &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberRecord.MemberId.ToProto(), - }) - } - - protoMember := &chatpb.ChatMember{ - MemberId: memberRecord.MemberId.ToProto(), - IsSelf: isSelf, - Identity: identity, - Pointers: pointers, - } - if protoMember.IsSelf { - protoMember.IsMuted = memberRecord.IsMuted - protoMember.IsSubscribed = !memberRecord.IsUnsubscribed - - if !memberRecord.IsUnsubscribed { - readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) - if memberRecord.ReadPointer != nil { - readPointer = *memberRecord.ReadPointer - } - unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, memberRecord.MemberId, readPointer) - if err != nil { - return nil, errors.Wrap(err, "error calculating unread count") - } - protoMember.NumUnread = unreadCount - } - } - - protoChat.Members = append(protoChat.Members, protoMember) + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - return protoChat, nil + // todo: send the push } func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { @@ -1532,3 +1450,16 @@ func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Acco return "", false, errors.Wrap(err, "error getting twitter user") } } + +func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.ChatMessage { + messageId := chat.GenerateMessageId() + ts, _ := messageId.GetTimestamp() + + return &chatpb.ChatMessage{ + MessageId: messageId.ToProto(), + SenderId: sender.ToProto(), + Content: content, + Ts: timestamppb.New(ts), + Cursor: &chatpb.Cursor{Value: messageId[:]}, + } +} From 8b228722a606b0f9a14464d25f0f056443cb7d89 Mon Sep 17 00:00:00 2001 From: Jeff Yanta Date: Thu, 20 Jun 2024 13:58:28 -0400 Subject: [PATCH 52/71] Incorporate small tweaks to chat APIs --- pkg/code/server/grpc/chat/v2/server.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 12757839..79200f98 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1079,7 +1079,8 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR switch err { case nil: return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_OK, + Result: chatpb.RevealIdentityResponse_OK, + Message: chatMessage, }, nil case chat.ErrMemberIdentityAlreadyUpgraded: return &chatpb.RevealIdentityResponse{ @@ -1224,7 +1225,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { protoChat := &chatpb.ChatMetadata{ ChatId: chatRecord.ChatId.ToProto(), - Kind: chatRecord.ChatType.ToProto(), + Type: chatRecord.ChatType.ToProto(), Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, } @@ -1270,7 +1271,7 @@ func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m } pointers = append(pointers, &chatpb.Pointer{ - Kind: optionalPointer.kind.ToProto(), + Type: optionalPointer.kind.ToProto(), Value: optionalPointer.value.ToProto(), MemberId: memberRecord.MemberId.ToProto(), }) From afb13a0f03459760c55a23f7ebd76cc1aee8cd97 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 8 Jul 2024 16:26:52 -0400 Subject: [PATCH 53/71] chats: add (temporary) plumbing for tip notifications. Adds enough plumbing for the tip notifications to flow through v2 (without breaking the previous system, hopefully). When the protocol is more flushed out, and the base messaging infra is shored up a bit, this will likely be revisted --- pkg/code/async/geyser/messenger.go | 2 +- pkg/code/chat/message_cash_transactions.go | 2 +- pkg/code/chat/message_code_team.go | 2 +- pkg/code/chat/message_kin_purchases.go | 2 +- pkg/code/chat/message_merchant.go | 2 +- pkg/code/chat/message_tips.go | 66 +++++++- pkg/code/chat/sender.go | 141 +++++++++++++++++- pkg/code/chat/sender_test.go | 13 +- pkg/code/data/chat/v2/id.go | 11 ++ pkg/code/server/grpc/chat/v1/server_test.go | 4 +- pkg/code/server/grpc/chat/v2/notifier.go | 22 +++ pkg/code/server/grpc/chat/v2/server.go | 43 ++++++ .../grpc/transaction/v2/history_test.go | 56 +++++++ pkg/code/server/grpc/transaction/v2/intent.go | 2 +- pkg/code/server/grpc/transaction/v2/server.go | 8 +- .../server/grpc/transaction/v2/testutil.go | 10 ++ 16 files changed, 357 insertions(+), 29 deletions(-) create mode 100644 pkg/code/server/grpc/chat/v2/notifier.go diff --git a/pkg/code/async/geyser/messenger.go b/pkg/code/async/geyser/messenger.go index afb8e999..c6b7f005 100644 --- a/pkg/code/async/geyser/messenger.go +++ b/pkg/code/async/geyser/messenger.go @@ -165,7 +165,7 @@ func processPotentialBlockchainMessage(ctx context.Context, data code_data.Provi return errors.Wrap(err, "error creating proto message") } - canPush, err := chat_util.SendChatMessage( + canPush, err := chat_util.SendNotificationChatMessageV1( ctx, data, asciiBaseDomain, diff --git a/pkg/code/chat/message_cash_transactions.go b/pkg/code/chat/message_cash_transactions.go index 8a94361f..620f32fb 100644 --- a/pkg/code/chat/message_cash_transactions.go +++ b/pkg/code/chat/message_cash_transactions.go @@ -148,7 +148,7 @@ func SendCashTransactionsExchangeMessage(ctx context.Context, data code_data.Pro return errors.Wrap(err, "error creating proto chat message") } - _, err = SendChatMessage( + _, err = SendNotificationChatMessageV1( ctx, data, CashTransactionsName, diff --git a/pkg/code/chat/message_code_team.go b/pkg/code/chat/message_code_team.go index d24e2305..423781cf 100644 --- a/pkg/code/chat/message_code_team.go +++ b/pkg/code/chat/message_code_team.go @@ -16,7 +16,7 @@ import ( // SendCodeTeamMessage sends a message to the Code Team chat. func SendCodeTeamMessage(ctx context.Context, data code_data.Provider, receiver *common.Account, chatMessage *chatpb.ChatMessage) (bool, error) { - return SendChatMessage( + return SendNotificationChatMessageV1( ctx, data, CodeTeamName, diff --git a/pkg/code/chat/message_kin_purchases.go b/pkg/code/chat/message_kin_purchases.go index 4377c247..1e637355 100644 --- a/pkg/code/chat/message_kin_purchases.go +++ b/pkg/code/chat/message_kin_purchases.go @@ -23,7 +23,7 @@ func GetKinPurchasesChatId(owner *common.Account) chat_v1.ChatId { // SendKinPurchasesMessage sends a message to the Kin Purchases chat. func SendKinPurchasesMessage(ctx context.Context, data code_data.Provider, receiver *common.Account, chatMessage *chatpb.ChatMessage) (bool, error) { - return SendChatMessage( + return SendNotificationChatMessageV1( ctx, data, KinPurchasesName, diff --git a/pkg/code/chat/message_merchant.go b/pkg/code/chat/message_merchant.go index a340f8bd..3acf3ed4 100644 --- a/pkg/code/chat/message_merchant.go +++ b/pkg/code/chat/message_merchant.go @@ -161,7 +161,7 @@ func SendMerchantExchangeMessage(ctx context.Context, data code_data.Provider, i return nil, errors.Wrap(err, "error creating proto chat message") } - canPush, err := SendChatMessage( + canPush, err := SendNotificationChatMessageV1( ctx, data, chatTitle, diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index 5752205a..6dceb051 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -2,15 +2,23 @@ package chat import ( "context" + "fmt" + "github.com/mr-tron/base58" "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" + chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" ) // SendTipsExchangeMessage sends a message to the Tips chat with exchange data @@ -18,7 +26,12 @@ import ( // Tips chat will be ignored. // // Note: Tests covered in SubmitIntent history tests -func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, intentRecord *intent.Record) ([]*MessageWithOwner, error) { +func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notifier chat_server.Notifier, intentRecord *intent.Record) ([]*MessageWithOwner, error) { + intentIdRaw, err := base58.Decode(intentRecord.IntentId) + if err != nil { + return nil, fmt.Errorf("invalid intent id: %w", err) + } + messageId := intentRecord.IntentId exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -30,7 +43,6 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten switch intentRecord.IntentType { case intent.SendPrivatePayment: if !intentRecord.SendPrivatePaymentMetadata.IsTip { - // Not a tip return nil, nil } @@ -61,30 +73,68 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, inten }, }, } - protoMessage, err := newProtoChatMessage(messageId, content, intentRecord.CreatedAt) + + v1Message, err := newProtoChatMessage(messageId, content, intentRecord.CreatedAt) if err != nil { return nil, errors.Wrap(err, "error creating proto chat message") } - canPush, err := SendChatMessage( + v2Message := &chatv2pb.ChatMessage{ + MessageId: chat_v2.GenerateMessageId().ToProto(), + Content: []*chatv2pb.Content{ + { + Type: &chatv2pb.Content_ExchangeData{ + ExchangeData: &chatv2pb.ExchangeDataContent{ + Verb: chatv2pb.ExchangeDataContent_Verb(verb), + ExchangeData: &chatv2pb.ExchangeDataContent_Exact{ + Exact: exchangeData, + }, + Reference: &chatv2pb.ExchangeDataContent_Intent{ + Intent: &commonpb.IntentId{Value: intentIdRaw}, + }, + }, + }, + }, + }, + Ts: timestamppb.New(intentRecord.CreatedAt), + } + + canPush, err := SendNotificationChatMessageV1( ctx, data, TipsName, chat_v1.ChatTypeInternal, true, receiver, - protoMessage, + v1Message, verb != chatpb.ExchangeDataContent_RECEIVED_TIP, ) - if err != nil && err != chat_v1.ErrMessageAlreadyExists { - return nil, errors.Wrap(err, "error persisting chat message") + if err != nil && !errors.Is(err, chat_v1.ErrMessageAlreadyExists) { + return nil, errors.Wrap(err, "error persisting v1 chat message") + } + + _, err = SendNotificationChatMessageV2( + ctx, + data, + notifier, + TipsName, + true, + receiver, + v2Message, + intentRecord.IntentId, + verb != chatpb.ExchangeDataContent_RECEIVED_TIP, + ) + if err != nil { + // TODO: Eventually we'll want to return an error, but for now we'll log + // since we're not in 'prod' yet. + logrus.StandardLogger().WithError(err).Warn("Failed to send notification message (v2)") } if canPush { messagesToPush = append(messagesToPush, &MessageWithOwner{ Owner: receiver, Title: TipsName, - Message: protoMessage, + Message: v1Message, }) } } diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 412e656b..571d8a6e 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -2,25 +2,30 @@ package chat import ( "context" + "database/sql" "errors" + "fmt" "time" "github.com/mr-tron/base58" "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" + chatserver "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" ) -// SendChatMessage sends a chat message to a receiving owner account. +// SendNotificationChatMessageV1 sends a chat message to a receiving owner account. // // Note: This function is not responsible for push notifications. This method // might be called within the context of a DB transaction, which might have -// unrelated failures. A hint as to whether a push should be sent is provided. -func SendChatMessage( +// unrelated failures. A hint whether a push should be sent is provided. +func SendNotificationChatMessageV1( ctx context.Context, data code_data.Provider, chatTitle string, @@ -80,7 +85,7 @@ func SendChatMessage( } err = data.PutChatV1(ctx, chatRecord) - if err != nil && err != chat_v1.ErrChatAlreadyExists { + if err != nil && !errors.Is(err, chat_v1.ErrChatAlreadyExists) { return false, err } default: @@ -115,3 +120,131 @@ func SendChatMessage( return canPushMessage, nil } + +func SendNotificationChatMessageV2( + ctx context.Context, + data code_data.Provider, + notifier chatserver.Notifier, + chatTitle string, + isVerifiedChat bool, + receiver *common.Account, + protoMessage *chatv2pb.ChatMessage, + intentId string, + isSilentMessage bool, +) (canPushMessage bool, err error) { + chatId := chat_v2.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) + + if protoMessage.Cursor != nil { + // Let the utilities and GetMessages RPC handle cursors + return false, errors.New("cursor must not be set") + } + + if err := protoMessage.Validate(); err != nil { + return false, err + } + + messageId, err := chat_v2.GetMessageIdFromProto(protoMessage.MessageId) + if err != nil { + return false, fmt.Errorf("invalid message id: %w", err) + } + + // Clear out extracted metadata as a space optimization + cloned := proto.Clone(protoMessage).(*chatv2pb.ChatMessage) + cloned.MessageId = nil + cloned.Ts = nil + cloned.Cursor = nil + + marshalled, err := proto.Marshal(cloned) + if err != nil { + return false, err + } + + canPersistMessage := true + canPushMessage = !isSilentMessage + + // + // Step 1: Check to see if we need to create the chat. + // + _, err = data.GetChatByIdV2(ctx, chatId) + if errors.Is(err, chat_v2.ErrChatNotFound) { + chatRecord := &chat_v2.ChatRecord{ + ChatId: chatId, + ChatType: chat_v2.ChatTypeNotification, + ChatTitle: &chatTitle, + IsVerified: isVerifiedChat, + + CreatedAt: time.Now(), + } + + err = data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + err = data.PutChatV2(ctx, chatRecord) + if err != nil && !errors.Is(err, chat_v2.ErrChatExists) { + return fmt.Errorf("failed to initialize chat: %w", err) + } + + err = data.PutChatMemberV2(ctx, &chat_v2.MemberRecord{ + ChatId: chatId, + MemberId: chat_v2.GenerateMemberId(), + Platform: chat_v2.PlatformCode, + PlatformId: receiver.PublicKey().ToBase58(), + JoinedAt: time.Now(), + }) + if err != nil { + return fmt.Errorf("failed to initialize chat with member: %w", err) + } + + return nil + }) + if err != nil { + return false, err + } + } else if err != nil { + return false, err + } + + // + // Step 2: Ensure that there is exactly 1 member in the chat. + // + members, err := data.GetAllChatMembersV2(ctx, chatId) + if errors.Is(err, chat_v2.ErrMemberNotFound) { // TODO: This is a weird error... + return false, nil + } else if err != nil { + return false, err + } + if len(members) > 1 { + // TODO: This _could_ get weird if client or someone else decides to join as another member. + return false, errors.New("notification chat should have at most 1 member") + } + + canPersistMessage = !members[0].IsUnsubscribed + canPushMessage = canPushMessage && canPersistMessage && !members[0].IsMuted + + if canPersistMessage { + refType := chat_v2.ReferenceTypeIntent + messageRecord := &chat_v2.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + + Data: marshalled, + IsSilent: isSilentMessage, + + ReferenceType: &refType, + Reference: &intentId, + } + + // TODO: Once we have a better idea on the data modeling around chatv2, + // we may wish to have the server manage the creation of messages + // (and chats?) as well. That would also put the + err = data.PutChatMessageV2(ctx, messageRecord) + if err != nil { + return false, err + } + + notifier.NotifyMessage(ctx, chatId, protoMessage) + } + + // TODO: Once we move more things over to chatv2, we will need to increment + // badge count here. We don't currently, as it would result in a double + // push. + return canPushMessage, nil +} diff --git a/pkg/code/chat/sender_test.go b/pkg/code/chat/sender_test.go index ac767fd9..18028c43 100644 --- a/pkg/code/chat/sender_test.go +++ b/pkg/code/chat/sender_test.go @@ -33,9 +33,8 @@ func TestSendChatMessage_HappyPath(t *testing.T) { chatMessage := newRandomChatMessage(t, i+1) expectedBadgeCount += 1 - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) - assert.True(t, canPush) assert.NotNil(t, chatMessage.MessageId) @@ -56,7 +55,7 @@ func TestSendChatMessage_VerifiedChat(t *testing.T) { for _, isVerified := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - _, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, isVerified, receiver, chatMessage, true) + _, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, isVerified, receiver, chatMessage, true) require.NoError(t, err) env.assertChatRecordSaved(t, chatTitle, receiver, isVerified) } @@ -71,7 +70,7 @@ func TestSendChatMessage_SilentMessage(t *testing.T) { for i, isSilent := range []bool{true, false} { chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, isSilent) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, isSilent) require.NoError(t, err) assert.Equal(t, !isSilent, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, isSilent) @@ -92,7 +91,7 @@ func TestSendChatMessage_MuteState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isMuted, canPush) env.assertChatMessageRecordSaved(t, chatId, chatMessage, false) @@ -113,7 +112,7 @@ func TestSendChatMessage_SubscriptionState(t *testing.T) { } chatMessage := newRandomChatMessage(t, 1) - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) require.NoError(t, err) assert.Equal(t, !isUnsubscribed, canPush) if isUnsubscribed { @@ -135,7 +134,7 @@ func TestSendChatMessage_InvalidProtoMessage(t *testing.T) { chatMessage := newRandomChatMessage(t, 1) chatMessage.Content = nil - canPush, err := SendChatMessage(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) + canPush, err := SendNotificationChatMessageV1(env.ctx, env.data, chatTitle, chat_v1.ChatTypeInternal, true, receiver, chatMessage, false) assert.Error(t, err) assert.False(t, canPush) env.assertChatRecordNotSaved(t, chatId) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index de912e48..5f5d54cb 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -2,7 +2,10 @@ package chat_v2 import ( "bytes" + "crypto/sha256" "encoding/hex" + "fmt" + "strings" "time" "github.com/google/uuid" @@ -29,6 +32,14 @@ func GetChatIdFromBytes(buffer []byte) (ChatId, error) { return typed, nil } +func GetChatId(sender, receiver string, isVerified bool) ChatId { + combined := []byte(fmt.Sprintf("%s:%s:%v", sender, receiver, isVerified)) + if strings.Compare(sender, receiver) > 0 { + combined = []byte(fmt.Sprintf("%s:%s:%v", receiver, sender, isVerified)) + } + return sha256.Sum256(combined) +} + // GetChatIdFromBytes gets a chat ID from the string representation func GetChatIdFromString(value string) (ChatId, error) { decoded, err := hex.DecodeString(value) diff --git a/pkg/code/server/grpc/chat/v1/server_test.go b/pkg/code/server/grpc/chat/v1/server_test.go index 0a9abad4..fe9eaf2b 100644 --- a/pkg/code/server/grpc/chat/v1/server_test.go +++ b/pkg/code/server/grpc/chat/v1/server_test.go @@ -894,12 +894,12 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { } func (e *testEnv) sendExternalAppChatMessage(t *testing.T, msg *chatpb.ChatMessage, domain string, isVerified bool, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, domain, chat.ChatTypeExternalApp, isVerified, recipient, msg, false) + _, err := chat_util.SendNotificationChatMessageV1(e.ctx, e.data, domain, chat.ChatTypeExternalApp, isVerified, recipient, msg, false) require.NoError(t, err) } func (e *testEnv) sendInternalChatMessage(t *testing.T, msg *chatpb.ChatMessage, chatTitle string, recipient *common.Account) { - _, err := chat_util.SendChatMessage(e.ctx, e.data, chatTitle, chat.ChatTypeInternal, true, recipient, msg, false) + _, err := chat_util.SendNotificationChatMessageV1(e.ctx, e.data, chatTitle, chat.ChatTypeInternal, true, recipient, msg, false) require.NoError(t, err) } diff --git a/pkg/code/server/grpc/chat/v2/notifier.go b/pkg/code/server/grpc/chat/v2/notifier.go new file mode 100644 index 00000000..994d95e9 --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/notifier.go @@ -0,0 +1,22 @@ +package chat_v2 + +import ( + "context" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" +) + +type Notifier interface { + NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) +} + +type NoopNotifier struct{} + +func NewNoopNotifier() *NoopNotifier { + return &NoopNotifier{} +} + +func (n *NoopNotifier) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.ChatMessage) { +} diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 79200f98..998db1ad 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -12,6 +12,7 @@ import ( "github.com/mr-tron/base58" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" "golang.org/x/text/language" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -783,6 +784,48 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest }, nil } +// TODO(api): This likely needs an RPC that can be called from any other server. +func (s *server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) error { + members, err := s.data.GetAllChatMembersV2(ctx, chatID) + if errors.Is(err, chat.ErrMemberNotFound) { + return nil + } else if err != nil { + return err + } + + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{Message: message}, + } + + var eg errgroup.Group + eg.SetLimit(min(32, len(members))) + + for _, m := range members { + eg.Go(func() error { + streamKey := fmt.Sprintf("%s:%s", chatID, m.MemberId.String()) + s.streamsMu.RLock() + stream := s.streams[streamKey] + s.streamsMu.RUnlock() + + if stream == nil { + return nil + } + + if err = stream.notify(event, time.Second); err != nil { + s.log.WithError(err). + WithField("member", m.MemberId.String()). + Info("Failed to notify chat stream") + } + + return nil + }) + } + + _ = eg.Wait() + + return nil +} + // todo: This belongs in the common chat utility, which currently only operates on v1 chats func (s *server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { if err := protoChatMessage.Validate(); err != nil { diff --git a/pkg/code/server/grpc/transaction/v2/history_test.go b/pkg/code/server/grpc/transaction/v2/history_test.go index 80ae0442..3ad912b7 100644 --- a/pkg/code/server/grpc/transaction/v2/history_test.go +++ b/pkg/code/server/grpc/transaction/v2/history_test.go @@ -3,6 +3,7 @@ package transaction_v2 import ( "testing" + "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -13,6 +14,7 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" @@ -338,6 +340,10 @@ func TestPaymentHistory_HappyPath(t *testing.T) { require.NoError(t, err) require.Len(t, chatMessageRecords, 1) + chatMessageRecordsV2, err := server.data.GetAllChatMessagesV2(server.ctx, chat_v2.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) + require.NoError(t, err) + requireEquivalent(t, chatMessageRecords, chatMessageRecordsV2) + protoChatMessage = getProtoChatMessage(t, chatMessageRecords[0]) require.Len(t, protoChatMessage.Content, 1) require.NotNil(t, protoChatMessage.Content[0].GetExchangeData()) @@ -413,6 +419,10 @@ func TestPaymentHistory_HappyPath(t *testing.T) { require.NoError(t, err) require.Len(t, chatMessageRecords, 1) + chatMessageRecordsV2, err = server.data.GetAllChatMessagesV2(server.ctx, chat_v2.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) + require.NoError(t, err) + requireEquivalent(t, chatMessageRecords, chatMessageRecordsV2) + protoChatMessage = getProtoChatMessage(t, chatMessageRecords[0]) require.Len(t, protoChatMessage.Content, 1) require.NotNil(t, protoChatMessage.Content[0].GetExchangeData()) @@ -422,3 +432,49 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 45.6, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(456), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) } + +func requireEquivalent(t *testing.T, v1 []*chat_v1.Message, v2 []*chat_v2.MessageRecord) { + require.Equal(t, len(v1), len(v2)) + + for i, v1Record := range v1 { + v2Record := v2[i] + + require.Equal(t, v1Record.ChatId[:], v2Record.ChatId[:]) + require.Equal(t, v1Record.IsSilent, v2Record.IsSilent) + + v1Message := getProtoChatMessage(t, v1Record) + require.Empty(t, v1Message.Sender) + + v2Message := getProtoChatMessageV2(t, v2Record) + require.Empty(t, v2Message.SenderId) + + require.Equal(t, len(v1Message.Content), len(v2Message.Content)) + + // TODO: Move this to somewhere common? + for c := range v1Message.Content { + a := v1Message.Content[c].GetExchangeData() + require.NotNil(t, a) + + b := v2Message.Content[c].GetExchangeData() + require.NotNil(t, b) + + require.EqualValues(t, a.Verb, b.Verb) + + if a.GetExact() != nil { + require.Equal(t, a.GetExact().Currency, b.GetExact().Currency) + require.Equal(t, a.GetExact().ExchangeRate, b.GetExact().ExchangeRate) + require.Equal(t, a.GetExact().NativeAmount, b.GetExact().NativeAmount) + require.Equal(t, a.GetExact().Quarks, b.GetExact().Quarks) + } else if a.GetPartial() != nil { + require.Equal(t, a.GetPartial().Currency, b.GetPartial().Currency) + require.Equal(t, a.GetPartial().NativeAmount, b.GetPartial().NativeAmount) + } else { + t.Fatal("Unhandled case") + } + + intent := b.GetIntent() + require.NotNil(t, intent) + require.Equal(t, v1Record.MessageId, base58.Encode(intent.Value)) + } + } +} diff --git a/pkg/code/server/grpc/transaction/v2/intent.go b/pkg/code/server/grpc/transaction/v2/intent.go index 9faffc47..5f0db81e 100644 --- a/pkg/code/server/grpc/transaction/v2/intent.go +++ b/pkg/code/server/grpc/transaction/v2/intent.go @@ -888,7 +888,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm return err } - tipMessagesToPush, err := chat_util.SendTipsExchangeMessage(ctx, s.data, intentRecord) + tipMessagesToPush, err := chat_util.SendTipsExchangeMessage(ctx, s.data, s.notifier, intentRecord) if err != nil { log.WithError(err).Warn("failure updating tips chat") return err diff --git a/pkg/code/server/grpc/transaction/v2/server.go b/pkg/code/server/grpc/transaction/v2/server.go index 3b852327..2bbbb551 100644 --- a/pkg/code/server/grpc/transaction/v2/server.go +++ b/pkg/code/server/grpc/transaction/v2/server.go @@ -14,6 +14,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/lawenforcement" + chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" "github.com/code-payments/code-server/pkg/code/server/grpc/messaging" "github.com/code-payments/code-server/pkg/jupiter" "github.com/code-payments/code-server/pkg/kin" @@ -29,7 +30,8 @@ type transactionServer struct { auth *auth_util.RPCSignatureVerifier - pusher push_lib.Provider + pusher push_lib.Provider + notifier chat_server.Notifier jupiterClient *jupiter.Client @@ -65,6 +67,7 @@ type transactionServer struct { func NewTransactionServer( data code_data.Provider, pusher push_lib.Provider, + notifier chat_server.Notifier, jupiterClient *jupiter.Client, messagingClient messaging.InternalMessageClient, maxmind *maxminddb.Reader, @@ -85,7 +88,8 @@ func NewTransactionServer( auth: auth_util.NewRPCSignatureVerifier(data), - pusher: pusher, + pusher: pusher, + notifier: notifier, jupiterClient: jupiterClient, diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 1da54d65..3941af36 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -25,6 +25,7 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" @@ -36,6 +37,7 @@ import ( "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/currency" "github.com/code-payments/code-server/pkg/code/data/deposit" @@ -55,6 +57,7 @@ import ( user_identity "github.com/code-payments/code-server/pkg/code/data/user/identity" "github.com/code-payments/code-server/pkg/code/data/vault" exchange_rate_util "github.com/code-payments/code-server/pkg/code/exchangerate" + chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" "github.com/code-payments/code-server/pkg/code/server/grpc/messaging" transaction_util "github.com/code-payments/code-server/pkg/code/transaction" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -184,6 +187,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, testService := NewTransactionServer( db, memory_push.NewPushProvider(), + chat_server.NewNoopNotifier(), nil, messaging.NewMessagingClient(db), nil, @@ -6178,3 +6182,9 @@ func getProtoChatMessage(t *testing.T, record *chat_v1.Message) *chatpb.ChatMess require.NoError(t, proto.Unmarshal(record.Data, &protoMessage)) return &protoMessage } + +func getProtoChatMessageV2(t *testing.T, record *chat_v2.MessageRecord) *chatv2pb.ChatMessage { + protoMessage := &chatv2pb.ChatMessage{} + require.NoError(t, proto.Unmarshal(record.Data, protoMessage)) + return protoMessage +} From 179e98f4e19eb857128b8e6db4accfe870a1d9fd Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 8 Jul 2024 16:50:43 -0400 Subject: [PATCH 54/71] chatv/2: expose server for use as a notifier --- pkg/code/server/grpc/chat/v2/server.go | 56 +++++++++++++------------- pkg/code/server/grpc/chat/v2/stream.go | 4 +- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 998db1ad..3efc775f 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -45,7 +45,7 @@ const ( flushMessageCount = 100 ) -type server struct { +type Server struct { log *logrus.Entry data code_data.Provider @@ -60,9 +60,9 @@ type server struct { chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - s := &server{ - log: logrus.StandardLogger().WithField("type", "chat/v2/server"), +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) *Server { + s := &Server{ + log: logrus.StandardLogger().WithField("type", "chat/v2/Server"), data: data, auth: auth, @@ -81,7 +81,7 @@ func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier } // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership -func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { +func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { log := s.log.WithField("method", "GetChats") log = client.InjectLoggingMetadata(ctx, log) @@ -181,7 +181,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch }, nil } -func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { +func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest) (*chatpb.GetMessagesResponse, error) { log := s.log.WithField("method", "GetMessages") log = client.InjectLoggingMetadata(ctx, log) @@ -284,7 +284,7 @@ func (s *server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest }, nil } -func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { +func (s *Server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { ctx := streamer.Context() log := s.log.WithField("method", "StreamChatEvents") @@ -359,9 +359,9 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e stream, exists := s.streams[streamKey] if exists { s.streamsMu.Unlock() - // There's an existing stream on this server that must be terminated first. + // There's an existing stream on this Server that must be terminated first. // Warn to see how often this happens in practice - log.Warnf("existing stream detected on this server (stream=%p) ; aborting", stream) + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) return status.Error(codes.Aborted, "stream already exists") } @@ -442,7 +442,7 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } } -func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { +func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { log := s.log.WithFields(logrus.Fields{ "method": "flushMessages", "chat_id": chatId.String(), @@ -477,7 +477,7 @@ func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *c } } -func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { +func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { log := s.log.WithFields(logrus.Fields{ "method": "flushPointers", "chat_id": chatId.String(), @@ -520,7 +520,7 @@ func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream * } } -func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { +func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { log := s.log.WithField("method", "StartChat") log = client.InjectLoggingMetadata(ctx, log) @@ -695,7 +695,7 @@ func (s *server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* } } -func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { +func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) { log := s.log.WithField("method", "SendMessage") log = client.InjectLoggingMetadata(ctx, log) @@ -784,8 +784,8 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest }, nil } -// TODO(api): This likely needs an RPC that can be called from any other server. -func (s *server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) error { +// TODO(api): This likely needs an RPC that can be called from any other Server. +func (s *Server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) error { members, err := s.data.GetAllChatMembersV2(ctx, chatID) if errors.Is(err, chat.ErrMemberNotFound) { return nil @@ -827,7 +827,7 @@ func (s *server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, messa } // todo: This belongs in the common chat utility, which currently only operates on v1 chats -func (s *server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { +func (s *Server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { if err := protoChatMessage.Validate(); err != nil { return errors.Wrap(err, "proto chat message failed validation") } @@ -877,7 +877,7 @@ func (s *server) persistChatMessage(ctx context.Context, chatId chat.ChatId, pro return nil } -func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { +func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { log := s.log.WithField("method", "AdvancePointer") log = client.InjectLoggingMetadata(ctx, log) @@ -981,7 +981,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR }, nil } -func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityRequest) (*chatpb.RevealIdentityResponse, error) { +func (s *Server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityRequest) (*chatpb.RevealIdentityResponse, error) { log := s.log.WithField("method", "RevealIdentity") log = client.InjectLoggingMetadata(ctx, log) @@ -1135,7 +1135,7 @@ func (s *server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR } } -func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { +func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { log := s.log.WithField("method", "SetMuteState") log = client.InjectLoggingMetadata(ctx, log) @@ -1200,7 +1200,7 @@ func (s *server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque }, nil } -func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { +func (s *Server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { log := s.log.WithField("method", "SetSubscriptionState") log = client.InjectLoggingMetadata(ctx, log) @@ -1265,7 +1265,7 @@ func (s *server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscr }, nil } -func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { +func (s *Server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { protoChat := &chatpb.ChatMetadata{ ChatId: chatRecord.ChatId.ToProto(), Type: chatRecord.ChatType.ToProto(), @@ -1349,7 +1349,7 @@ func (s *server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m return protoChat, nil } -func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { +func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { messageRecords, err := s.data.GetAllChatMessagesV2( ctx, chatId, @@ -1405,7 +1405,7 @@ func (s *server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o return res, nil } -func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { +func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { event := &chatpb.ChatStreamEvent{ Type: &chatpb.ChatStreamEvent_Message{ Message: chatMessage, @@ -1418,7 +1418,7 @@ func (s *server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, cha // todo: send the push } -func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { +func (s *Server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { identities := map[chat.Platform]string{ chat.PlatformCode: owner.PublicKey().ToBase58(), } @@ -1434,7 +1434,7 @@ func (s *server) getAllIdentities(ctx context.Context, owner *common.Account) (m return identities, nil } -func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { +func (s *Server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) switch err { case nil: @@ -1447,7 +1447,7 @@ func (s *server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.Ch return s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) } -func (s *server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatId, memberRecord *chat.MemberRecord, owner *common.Account) (bool, error) { +func (s *Server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatId, memberRecord *chat.MemberRecord, owner *common.Account) (bool, error) { switch memberRecord.Platform { case chat.PlatformCode: return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil @@ -1459,7 +1459,7 @@ func (s *server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatI } // todo: This logic should live elsewhere in somewhere more common -func (s *server) ownsTwitterUsername(ctx context.Context, owner *common.Account, username string) (bool, error) { +func (s *Server) ownsTwitterUsername(ctx context.Context, owner *common.Account, username string) (bool, error) { ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) if err != nil { return false, errors.Wrap(err, "error deriving twitter tip address") @@ -1478,7 +1478,7 @@ func (s *server) ownsTwitterUsername(ctx context.Context, owner *common.Account, } // todo: This logic should live elsewhere in somewhere more common -func (s *server) getOwnedTwitterUsername(ctx context.Context, owner *common.Account) (string, bool, error) { +func (s *Server) getOwnedTwitterUsername(ctx context.Context, owner *common.Account) (string, bool, error) { ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) if err != nil { return "", false, errors.Wrap(err, "error deriving twitter tip address") diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go index 3d39428d..a5fa2017 100644 --- a/pkg/code/server/grpc/chat/v2/stream.go +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -98,7 +98,7 @@ type chatEventNotification struct { ts time.Time } -func (s *server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEvent) error { +func (s *Server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEvent) error { m := proto.Clone(event).(*chatpb.ChatStreamEvent) ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, m, time.Now()}) if !ok { @@ -107,7 +107,7 @@ func (s *server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEven return nil } -func (s *server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { +func (s *Server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { log := s.log.WithFields(logrus.Fields{ "method": "asyncChatEventStreamNotifier", "worker": workerId, From 386695c1bce6964f37f77c812f46783d2a01a5f3 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 8 Jul 2024 17:13:12 -0400 Subject: [PATCH 55/71] chat/v2: fix Notify message signature --- pkg/code/server/grpc/chat/v2/server.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 3efc775f..9dea3314 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -785,12 +785,16 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } // TODO(api): This likely needs an RPC that can be called from any other Server. -func (s *Server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) error { +func (s *Server) NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) { members, err := s.data.GetAllChatMembersV2(ctx, chatID) if errors.Is(err, chat.ErrMemberNotFound) { - return nil + return } else if err != nil { - return err + s.log.WithError(err). + WithField("chat_id", chatID.String()). + Warn("failed to get members for chat notification") + + return } event := &chatpb.ChatStreamEvent{ @@ -822,8 +826,6 @@ func (s *Server) NotifyNewMessage(ctx context.Context, chatID chat.ChatId, messa } _ = eg.Wait() - - return nil } // todo: This belongs in the common chat utility, which currently only operates on v1 chats From 47a49a6db2443b5655b5a7b4047b83eb168d75d3 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 10 Jul 2024 13:14:25 -0400 Subject: [PATCH 56/71] chat: add debug logging --- pkg/code/chat/sender.go | 14 +++++++++++++- pkg/code/server/grpc/chat/v2/server.go | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 571d8a6e..0ab28761 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -8,6 +8,7 @@ import ( "time" "github.com/mr-tron/base58" + "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" @@ -132,6 +133,8 @@ func SendNotificationChatMessageV2( intentId string, isSilentMessage bool, ) (canPushMessage bool, err error) { + log := logrus.StandardLogger().WithField("type", "sendNotificationChatMessageV2") + chatId := chat_v2.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) if protoMessage.Cursor != nil { @@ -182,9 +185,10 @@ func SendNotificationChatMessageV2( return fmt.Errorf("failed to initialize chat: %w", err) } + memberId := chat_v2.GenerateMemberId() err = data.PutChatMemberV2(ctx, &chat_v2.MemberRecord{ ChatId: chatId, - MemberId: chat_v2.GenerateMemberId(), + MemberId: memberId, Platform: chat_v2.PlatformCode, PlatformId: receiver.PublicKey().ToBase58(), JoinedAt: time.Now(), @@ -193,6 +197,12 @@ func SendNotificationChatMessageV2( return fmt.Errorf("failed to initialize chat with member: %w", err) } + log.WithFields(logrus.Fields{ + "chat_id": chatId.String(), + "member": memberId.String(), + "platform_id": receiver.PublicKey().ToBase58(), + }).Info("Initialized chat for tip") + return nil }) if err != nil { @@ -241,6 +251,8 @@ func SendNotificationChatMessageV2( } notifier.NotifyMessage(ctx, chatId, protoMessage) + + log.WithField("chat_id", chatId.String()).Info("Put and notified") } // TODO: Once we move more things over to chatv2, we will need to increment diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 9dea3314..d0bc6d19 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -133,7 +133,7 @@ func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch // todo: Use a better query that returns chat IDs. This will result in duplicate // chat results if the user is in the chat multiple times across many identities. - patformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( + platformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( ctx, myIdentities, query.WithCursor(cursor), @@ -149,8 +149,10 @@ func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch return nil, status.Error(codes.Internal, "") } + log.WithField("chats", len(platformUserMemberRecords)).Info("Retrieved chatlist for user") + var protoChats []*chatpb.ChatMetadata - for _, platformUserMemberRecord := range patformUserMemberRecords { + for _, platformUserMemberRecord := range platformUserMemberRecords { log := log.WithField("chat_id", platformUserMemberRecord.ChatId.String()) chatRecord, err := s.data.GetChatByIdV2(ctx, platformUserMemberRecord.ChatId) @@ -786,8 +788,14 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest // TODO(api): This likely needs an RPC that can be called from any other Server. func (s *Server) NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) { + log := s.log.WithFields(logrus.Fields{ + "chat_id": chatID.String(), + "messge_id": message.MessageId.String(), + }) + members, err := s.data.GetAllChatMembersV2(ctx, chatID) if errors.Is(err, chat.ErrMemberNotFound) { + log.Info("Dropping message notification, no members") return } else if err != nil { s.log.WithError(err). @@ -805,6 +813,8 @@ func (s *Server) NotifyMessage(ctx context.Context, chatID chat.ChatId, message eg.SetLimit(min(32, len(members))) for _, m := range members { + m := m + eg.Go(func() error { streamKey := fmt.Sprintf("%s:%s", chatID, m.MemberId.String()) s.streamsMu.RLock() @@ -815,6 +825,7 @@ func (s *Server) NotifyMessage(ctx context.Context, chatID chat.ChatId, message return nil } + log.WithField("member_id", m.MemberId.String()).Info("Notifying member stream") if err = stream.notify(event, time.Second); err != nil { s.log.WithError(err). WithField("member", m.MemberId.String()). From da3481566e372ec2c26c70306464a9f4249d3db5 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 15 Jul 2024 10:20:43 -0400 Subject: [PATCH 57/71] chat: don't execute v2 member add in tx. Turns out we're being called in a transaction up in the chain. There should be a better way to detect (or noop) this behaviour Turns out we're being called in a transaction up in the chain. There should be a better way to detect (or noop) this behaviour --- pkg/code/chat/sender.go | 50 +++++++++++++++++++---------------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 0ab28761..b018fd73 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -2,7 +2,6 @@ package chat import ( "context" - "database/sql" "errors" "fmt" "time" @@ -179,35 +178,32 @@ func SendNotificationChatMessageV2( CreatedAt: time.Now(), } - err = data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { - err = data.PutChatV2(ctx, chatRecord) - if err != nil && !errors.Is(err, chat_v2.ErrChatExists) { - return fmt.Errorf("failed to initialize chat: %w", err) - } - - memberId := chat_v2.GenerateMemberId() - err = data.PutChatMemberV2(ctx, &chat_v2.MemberRecord{ - ChatId: chatId, - MemberId: memberId, - Platform: chat_v2.PlatformCode, - PlatformId: receiver.PublicKey().ToBase58(), - JoinedAt: time.Now(), - }) - if err != nil { - return fmt.Errorf("failed to initialize chat with member: %w", err) - } - - log.WithFields(logrus.Fields{ - "chat_id": chatId.String(), - "member": memberId.String(), - "platform_id": receiver.PublicKey().ToBase58(), - }).Info("Initialized chat for tip") - - return nil + // TODO: These should be run in a transaction, but so far + // we're being called in a transaction. We should have some kind + // of safety check here... + err = data.PutChatV2(ctx, chatRecord) + if err != nil && !errors.Is(err, chat_v2.ErrChatExists) { + return false, fmt.Errorf("failed to initialize chat: %w", err) + } + + memberId := chat_v2.GenerateMemberId() + err = data.PutChatMemberV2(ctx, &chat_v2.MemberRecord{ + ChatId: chatId, + MemberId: memberId, + Platform: chat_v2.PlatformCode, + PlatformId: receiver.PublicKey().ToBase58(), + JoinedAt: time.Now(), }) if err != nil { - return false, err + return false, fmt.Errorf("failed to initialize chat with member: %w", err) } + + log.WithFields(logrus.Fields{ + "chat_id": chatId.String(), + "member": memberId.String(), + "platform_id": receiver.PublicKey().ToBase58(), + }).Info("Initialized chat for tip") + } else if err != nil { return false, err } From ca8fcf46df6d0b1ea476154f2ae5b917de4bad05 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 15 Jul 2024 10:43:38 -0400 Subject: [PATCH 58/71] chat/v2: add handling for notifiaction chats --- pkg/code/server/grpc/chat/v2/server.go | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index d0bc6d19..074caa37 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1291,6 +1291,15 @@ func (s *Server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m protoChat.CanMute = true protoChat.CanUnsubscribe = true + case chat.ChatTypeNotification: + if chatRecord.ChatTitle == nil { + // TODO: we shouldn't fail the whole RPC + return nil, fmt.Errorf("invalid notification chat: missing title") + } + + // TODO: Localization + protoChat.Title = *chatRecord.ChatTitle + default: return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) } From 91956a6465faa0293f494e825e3250e70369db48 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Sat, 20 Jul 2024 22:24:18 -0400 Subject: [PATCH 59/71] chat/v2: send two way chat notifications. Also refactor a bit of the notifier system to not cause a circular dependency. --- pkg/code/chat/chat.go | 6 + pkg/code/chat/message_tips.go | 3 +- .../{server/grpc/chat/v2 => chat}/notifier.go | 2 +- pkg/code/chat/sender.go | 3 +- pkg/code/data/chat/v2/model.go | 22 ++- pkg/code/localization/keys.go | 1 + pkg/code/push/notifications.go | 145 ++++++++++++++++++ pkg/code/server/grpc/chat/v2/server.go | 118 ++++++++------ pkg/code/server/grpc/transaction/v2/server.go | 6 +- .../server/grpc/transaction/v2/testutil.go | 4 +- 10 files changed, 249 insertions(+), 61 deletions(-) rename pkg/code/{server/grpc/chat/v2 => chat}/notifier.go (96%) diff --git a/pkg/code/chat/chat.go b/pkg/code/chat/chat.go index e23f6555..2a5b20d0 100644 --- a/pkg/code/chat/chat.go +++ b/pkg/code/chat/chat.go @@ -8,6 +8,7 @@ const ( KinPurchasesName = "Kin Purchases" PaymentsName = "Payments" // Renamed to Web Payments on client TipsName = "Tips" + TwoWayChatName = "Two Way Chat" // Test chats used for unit/integration testing only TestCantMuteName = "TestCantMute" @@ -45,6 +46,11 @@ var ( CanMute: true, CanUnsubscribe: false, }, + TwoWayChatName: { + TitleLocalizationKey: localization.ChatTitleTwoWay, + CanMute: true, + CanUnsubscribe: false, + }, TestCantMuteName: { TitleLocalizationKey: "n/a", diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index 6dceb051..81a18f83 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -18,7 +18,6 @@ import ( chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" - chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" ) // SendTipsExchangeMessage sends a message to the Tips chat with exchange data @@ -26,7 +25,7 @@ import ( // Tips chat will be ignored. // // Note: Tests covered in SubmitIntent history tests -func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notifier chat_server.Notifier, intentRecord *intent.Record) ([]*MessageWithOwner, error) { +func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notifier Notifier, intentRecord *intent.Record) ([]*MessageWithOwner, error) { intentIdRaw, err := base58.Decode(intentRecord.IntentId) if err != nil { return nil, fmt.Errorf("invalid intent id: %w", err) diff --git a/pkg/code/server/grpc/chat/v2/notifier.go b/pkg/code/chat/notifier.go similarity index 96% rename from pkg/code/server/grpc/chat/v2/notifier.go rename to pkg/code/chat/notifier.go index 994d95e9..15839aa2 100644 --- a/pkg/code/server/grpc/chat/v2/notifier.go +++ b/pkg/code/chat/notifier.go @@ -1,4 +1,4 @@ -package chat_v2 +package chat import ( "context" diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index b018fd73..4b56f31c 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -17,7 +17,6 @@ import ( code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" - chatserver "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" ) // SendNotificationChatMessageV1 sends a chat message to a receiving owner account. @@ -124,7 +123,7 @@ func SendNotificationChatMessageV1( func SendNotificationChatMessageV2( ctx context.Context, data code_data.Provider, - notifier chatserver.Notifier, + notifier Notifier, chatTitle string, isVerifiedChat bool, receiver *common.Account, diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index ef3c7071..bae36478 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -70,6 +70,11 @@ type MemberRecord struct { Platform Platform PlatformId string + // If Platform != PlatformCode, this store the owner + // of the account (at time of creation). This allows + // us to send push notifications for non-code users. + OwnerAccount string + DeliveryPointer *MessageId ReadPointer *MessageId @@ -79,6 +84,14 @@ type MemberRecord struct { JoinedAt time.Time } +func (m *MemberRecord) GetOwner() string { + if m.Platform == PlatformCode { + return m.PlatformId + } + + return m.OwnerAccount +} + type MessageRecord struct { Id int64 ChatId ChatId @@ -292,6 +305,9 @@ func (r *MemberRecord) Validate() error { if len(r.PlatformId) == 0 { return errors.New("platform id is required") } + if r.Platform != PlatformCode && len(r.OwnerAccount) == 0 { + return errors.New("owner account is required for non code platform members") + } switch r.Platform { case PlatformCode: @@ -349,8 +365,9 @@ func (r *MemberRecord) Clone() MemberRecord { ChatId: r.ChatId, MemberId: r.MemberId, - Platform: r.Platform, - PlatformId: r.PlatformId, + Platform: r.Platform, + PlatformId: r.PlatformId, + OwnerAccount: r.OwnerAccount, DeliveryPointer: deliveryPointerCopy, ReadPointer: readPointerCopy, @@ -370,6 +387,7 @@ func (r *MemberRecord) CopyTo(dst *MemberRecord) { dst.Platform = r.Platform dst.PlatformId = r.PlatformId + dst.OwnerAccount = r.OwnerAccount if r.DeliveryPointer != nil { cloned := r.DeliveryPointer.Clone() diff --git a/pkg/code/localization/keys.go b/pkg/code/localization/keys.go index edd557df..926acc64 100644 --- a/pkg/code/localization/keys.go +++ b/pkg/code/localization/keys.go @@ -43,6 +43,7 @@ const ( ChatTitleKinPurchases = "title.chat.kinPurchases" ChatTitlePayments = "title.chat.payments" ChatTitleTips = "title.chat.tips" + ChatTitleTwoWay = "title.chat.twoWay" // Message Bodies diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index bfe4db7c..b1a5c9b2 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -9,6 +9,7 @@ import ( "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" + chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" chat_util "github.com/code-payments/code-server/pkg/code/chat" @@ -412,3 +413,147 @@ func SendChatMessagePushNotification( } return nil } + +func SendChatMessagePushNotificationV2( + ctx context.Context, + data code_data.Provider, + pusher push_lib.Provider, + chatTitle string, + owner *common.Account, + chatMessage *chatv2pb.ChatMessage, +) error { + log := logrus.StandardLogger().WithFields(logrus.Fields{ + "method": "SendChatMessagePushNotificationV2", + "owner": owner.PublicKey().ToBase58(), + "chat": chatTitle, + }) + + // Best-effort try to update the badge count before pushing message content + // + // Note: Only chat messages generate badge counts + err := UpdateBadgeCount(ctx, data, pusher, owner) + if err != nil { + log.WithError(err).Warn("failure updating badge count on device") + } + + locale, err := data.GetUserLocale(ctx, owner.PublicKey().ToBase58()) + if err != nil { + log.WithError(err).Warn("failure getting user locale") + return err + } + + var localizedPushTitle string + + chatProperties, ok := chat_util.InternalChatProperties[chatTitle] + if ok { + localized, err := localization.Localize(locale, chatProperties.TitleLocalizationKey) + if err != nil { + return nil + } + localizedPushTitle = localized + } else { + domainDisplayName, err := thirdparty.GetDomainDisplayName(chatTitle) + if err == nil { + localizedPushTitle = domainDisplayName + } else { + return nil + } + } + + var anyErrorPushingContent bool + for _, content := range chatMessage.Content { + var contentToPush *chatv2pb.Content + switch typedContent := content.Type.(type) { + case *chatv2pb.Content_Localized: + localizedPushBody, err := localization.Localize(locale, typedContent.Localized.KeyOrText) + if err != nil { + continue + } + + contentToPush = &chatv2pb.Content{ + Type: &chatv2pb.Content_Localized{ + Localized: &chatv2pb.LocalizedContent{ + KeyOrText: localizedPushBody, + }, + }, + } + case *chatv2pb.Content_ExchangeData: + var currencyCode currency_lib.Code + var nativeAmount float64 + if typedContent.ExchangeData.GetExact() != nil { + exchangeData := typedContent.ExchangeData.GetExact() + currencyCode = currency_lib.Code(exchangeData.Currency) + nativeAmount = exchangeData.NativeAmount + } else { + exchangeData := typedContent.ExchangeData.GetPartial() + currencyCode = currency_lib.Code(exchangeData.Currency) + nativeAmount = exchangeData.NativeAmount + } + + localizedPushBody, err := localization.LocalizeFiatWithVerb( + locale, + chatpb.ExchangeDataContent_Verb(typedContent.ExchangeData.Verb), + currencyCode, + nativeAmount, + true, + ) + if err != nil { + continue + } + + contentToPush = &chatv2pb.Content{ + Type: &chatv2pb.Content_Localized{ + Localized: &chatv2pb.LocalizedContent{ + KeyOrText: localizedPushBody, + }, + }, + } + case *chatv2pb.Content_NaclBox, *chatv2pb.Content_Text: + contentToPush = content + case *chatv2pb.Content_ThankYou: + contentToPush = &chatv2pb.Content{ + Type: &chatv2pb.Content_Localized{ + Localized: &chatv2pb.LocalizedContent{ + // todo: localize this + KeyOrText: "🙏 They thanked you for their tip", + }, + }, + } + } + + if contentToPush == nil { + continue + } + + marshalledContent, err := proto.Marshal(contentToPush) + if err != nil { + log.WithError(err).Warn("failure marshalling chat content") + return err + } + + kvs := map[string]string{ + "chat_title": localizedPushTitle, + "message_content": base64.StdEncoding.EncodeToString(marshalledContent), + } + + err = sendMutableNotificationToOwner( + ctx, + data, + pusher, + owner, + chatMessageDataPush, + chatTitle, + kvs, + ) + if err != nil { + anyErrorPushingContent = true + log.WithError(err).Warn("failure sending data push notification") + } + } + + if anyErrorPushingContent { + return errors.New("at least one piece of content failed to push") + } + + return nil +} diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 074caa37..785bd074 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -25,14 +25,17 @@ import ( transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" auth_util "github.com/code-payments/code-server/pkg/code/auth" + chatv2 "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/data/twitter" "github.com/code-payments/code-server/pkg/code/localization" + push_util "github.com/code-payments/code-server/pkg/code/push" "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" + "github.com/code-payments/code-server/pkg/push" timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" sync_util "github.com/code-payments/code-server/pkg/sync" ) @@ -50,6 +53,7 @@ type Server struct { data code_data.Provider auth *auth_util.RPCSignatureVerifier + push push.Provider streamsMu sync.RWMutex streams map[string]*chatEventStream @@ -60,12 +64,17 @@ type Server struct { chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) *Server { +func NewChatServer( + data code_data.Provider, + auth *auth_util.RPCSignatureVerifier, + push push.Provider, +) *Server { s := &Server{ log: logrus.StandardLogger().WithField("type", "chat/v2/Server"), data: data, auth: auth, + push: push, streams: make(map[string]*chatEventStream), @@ -631,8 +640,9 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* ChatId: chatId, MemberId: chat.GenerateMemberId(), - Platform: chat.PlatformTwitter, - PlatformId: twitterUsername, + Platform: chat.PlatformTwitter, + PlatformId: twitterUsername, + OwnerAccount: owner.PublicKey().ToBase58(), JoinedAt: creationTs, }, @@ -748,8 +758,10 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, status.Error(codes.Internal, "") } + var chatTitle string switch chatRecord.ChatType { case chat.ChatTypeTwoWay: + chatTitle = chatv2.TwoWayChatName default: return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_INVALID_CHAT_TYPE, @@ -779,6 +791,7 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } s.onPersistChatMessage(log, chatId, chatMessage) + s.sendPushNotifications(chatId, chatTitle, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, @@ -787,56 +800,13 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } // TODO(api): This likely needs an RPC that can be called from any other Server. -func (s *Server) NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) { +func (s *Server) NotifyMessage(_ context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) { log := s.log.WithFields(logrus.Fields{ "chat_id": chatID.String(), "messge_id": message.MessageId.String(), }) - members, err := s.data.GetAllChatMembersV2(ctx, chatID) - if errors.Is(err, chat.ErrMemberNotFound) { - log.Info("Dropping message notification, no members") - return - } else if err != nil { - s.log.WithError(err). - WithField("chat_id", chatID.String()). - Warn("failed to get members for chat notification") - - return - } - - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{Message: message}, - } - - var eg errgroup.Group - eg.SetLimit(min(32, len(members))) - - for _, m := range members { - m := m - - eg.Go(func() error { - streamKey := fmt.Sprintf("%s:%s", chatID, m.MemberId.String()) - s.streamsMu.RLock() - stream := s.streams[streamKey] - s.streamsMu.RUnlock() - - if stream == nil { - return nil - } - - log.WithField("member_id", m.MemberId.String()).Info("Notifying member stream") - if err = stream.notify(event, time.Second); err != nil { - s.log.WithError(err). - WithField("member", m.MemberId.String()). - Info("Failed to notify chat stream") - } - - return nil - }) - } - - _ = eg.Wait() + s.onPersistChatMessage(log, chatID, message) } // todo: This belongs in the common chat utility, which currently only operates on v1 chats @@ -1433,11 +1403,61 @@ func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, cha Message: chatMessage, }, } + if err := s.asyncNotifyAll(chatId, event); err != nil { log.WithError(err).Warn("failure notifying chat event") } +} + +func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, message *chatpb.ChatMessage) { + log := s.log.WithFields(logrus.Fields{ + "method": "sendPushNotifications", + "chat_id": chatId.String(), + }) + + // todo: err group + members, err := s.data.GetAllChatMembersV2(context.Background(), chatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + var eg errgroup.Group + eg.SetLimit(min(32, len(members))) + + for _, m := range members { + if m.IsMuted || m.IsUnsubscribed { + continue + } - // todo: send the push + owner, err := common.NewAccountFromPublicKeyString(m.GetOwner()) + if err != nil { + log.WithError(err).WithField("member", m.MemberId.String()).Warn("failure getting owner") + continue + } + + m := m + eg.Go(func() error { + err = push_util.SendChatMessagePushNotificationV2( + context.Background(), + s.data, + s.push, + chatTitle, + owner, + message, + ) + if err != nil { + log. + WithError(err). + WithField("member", m.MemberId). + Warn("failure sending push notification") + } + + return nil + }) + } + + _ = eg.Wait() } func (s *Server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { diff --git a/pkg/code/server/grpc/transaction/v2/server.go b/pkg/code/server/grpc/transaction/v2/server.go index 2bbbb551..5ba0bee0 100644 --- a/pkg/code/server/grpc/transaction/v2/server.go +++ b/pkg/code/server/grpc/transaction/v2/server.go @@ -11,10 +11,10 @@ import ( "github.com/code-payments/code-server/pkg/code/antispam" auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/lawenforcement" - chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" "github.com/code-payments/code-server/pkg/code/server/grpc/messaging" "github.com/code-payments/code-server/pkg/jupiter" "github.com/code-payments/code-server/pkg/kin" @@ -31,7 +31,7 @@ type transactionServer struct { auth *auth_util.RPCSignatureVerifier pusher push_lib.Provider - notifier chat_server.Notifier + notifier chat.Notifier jupiterClient *jupiter.Client @@ -67,7 +67,7 @@ type transactionServer struct { func NewTransactionServer( data code_data.Provider, pusher push_lib.Provider, - notifier chat_server.Notifier, + notifier chat.Notifier, jupiterClient *jupiter.Client, messagingClient messaging.InternalMessageClient, maxmind *maxminddb.Reader, diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index 3941af36..a160c2db 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -31,6 +31,7 @@ import ( transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" "github.com/code-payments/code-server/pkg/code/antispam" + "github.com/code-payments/code-server/pkg/code/chat" chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" @@ -57,7 +58,6 @@ import ( user_identity "github.com/code-payments/code-server/pkg/code/data/user/identity" "github.com/code-payments/code-server/pkg/code/data/vault" exchange_rate_util "github.com/code-payments/code-server/pkg/code/exchangerate" - chat_server "github.com/code-payments/code-server/pkg/code/server/grpc/chat/v2" "github.com/code-payments/code-server/pkg/code/server/grpc/messaging" transaction_util "github.com/code-payments/code-server/pkg/code/transaction" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -187,7 +187,7 @@ func setupTestEnv(t *testing.T, serverOverrides *testOverrides) (serverTestEnv, testService := NewTransactionServer( db, memory_push.NewPushProvider(), - chat_server.NewNoopNotifier(), + chat.NewNoopNotifier(), nil, messaging.NewMessagingClient(db), nil, From e78b0281322fb5682b48a2f650eb785ea00e6c40 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 22 Jul 2024 13:27:37 -0400 Subject: [PATCH 60/71] chat/v2: don't send push to sender --- pkg/code/server/grpc/chat/v2/server.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 785bd074..80f299dc 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -791,7 +791,7 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } s.onPersistChatMessage(log, chatId, chatMessage) - s.sendPushNotifications(chatId, chatTitle, chatMessage) + s.sendPushNotifications(chatId, chatTitle, memberId, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, @@ -1409,7 +1409,7 @@ func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, cha } } -func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, message *chatpb.ChatMessage) { +func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sender chat.MemberId, message *chatpb.ChatMessage) { log := s.log.WithFields(logrus.Fields{ "method": "sendPushNotifications", "chat_id": chatId.String(), @@ -1426,7 +1426,7 @@ func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, mes eg.SetLimit(min(32, len(members))) for _, m := range members { - if m.IsMuted || m.IsUnsubscribed { + if m.MemberId == sender || m.IsMuted || m.IsUnsubscribed { continue } From f75156927c8ac9ddfbf13a325f6368e844a5d3fd Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 24 Jul 2024 18:15:01 -0400 Subject: [PATCH 61/71] chat/v2: add chat_id to push header, debug logs --- pkg/code/push/notifications.go | 3 +++ pkg/code/server/grpc/chat/v2/server.go | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index b1a5c9b2..5bafd6c1 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -16,6 +16,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" + chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/localization" "github.com/code-payments/code-server/pkg/code/thirdparty" currency_lib "github.com/code-payments/code-server/pkg/currency" @@ -418,6 +419,7 @@ func SendChatMessagePushNotificationV2( ctx context.Context, data code_data.Provider, pusher push_lib.Provider, + chatId chat_v2.ChatId, chatTitle string, owner *common.Account, chatMessage *chatv2pb.ChatMessage, @@ -533,6 +535,7 @@ func SendChatMessagePushNotificationV2( kvs := map[string]string{ "chat_title": localizedPushTitle, + "chat_id": chatId.String(), "message_content": base64.StdEncoding.EncodeToString(marshalledContent), } diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 80f299dc..4b719698 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1412,10 +1412,11 @@ func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, cha func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sender chat.MemberId, message *chatpb.ChatMessage) { log := s.log.WithFields(logrus.Fields{ "method": "sendPushNotifications", + "sender": sender.String(), "chat_id": chatId.String(), }) - // todo: err group + // TODO: Callers might already have this loaded. members, err := s.data.GetAllChatMembersV2(context.Background(), chatId) if err != nil { log.WithError(err).Warn("failure getting chat members") @@ -1438,10 +1439,12 @@ func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sen m := m eg.Go(func() error { + log.WithField("member", m.MemberId.String()).Info("sending push notification") err = push_util.SendChatMessagePushNotificationV2( context.Background(), s.data, s.push, + chatId, chatTitle, owner, message, From 0dd205afc9f82b332e6926ffa33f509230c5f490 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 4 Sep 2024 11:12:19 -0400 Subject: [PATCH 62/71] chat/v2: remove tip hooks, sketch out latest rpcs --- go.mod | 2 +- pkg/code/chat/message_tips.go | 54 +--- pkg/code/data/chat/v2/memory/store.go | 13 +- pkg/code/data/chat/v2/store.go | 24 +- pkg/code/data/internal.go | 8 +- pkg/code/server/grpc/chat/v1/server.go | 110 +++++--- pkg/code/server/grpc/chat/v2/server.go | 241 ++++++++---------- .../grpc/transaction/v2/history_test.go | 56 ---- 8 files changed, 226 insertions(+), 282 deletions(-) diff --git a/go.mod b/go.mod index 4e808287..cc45d8ea 100644 --- a/go.mod +++ b/go.mod @@ -42,6 +42,7 @@ require ( golang.org/x/crypto v0.21.0 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/net v0.22.0 + golang.org/x/sync v0.7.0 golang.org/x/text v0.14.0 golang.org/x/time v0.5.0 google.golang.org/api v0.170.0 @@ -120,7 +121,6 @@ require ( go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.18.0 // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/appengine/v2 v2.0.1 // indirect diff --git a/pkg/code/chat/message_tips.go b/pkg/code/chat/message_tips.go index 81a18f83..5500bc89 100644 --- a/pkg/code/chat/message_tips.go +++ b/pkg/code/chat/message_tips.go @@ -2,22 +2,12 @@ package chat import ( "context" - "fmt" - - "github.com/mr-tron/base58" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" - "google.golang.org/protobuf/types/known/timestamppb" - chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" - chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" - commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" - "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" - chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/pkg/errors" ) // SendTipsExchangeMessage sends a message to the Tips chat with exchange data @@ -26,11 +16,6 @@ import ( // // Note: Tests covered in SubmitIntent history tests func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notifier Notifier, intentRecord *intent.Record) ([]*MessageWithOwner, error) { - intentIdRaw, err := base58.Decode(intentRecord.IntentId) - if err != nil { - return nil, fmt.Errorf("invalid intent id: %w", err) - } - messageId := intentRecord.IntentId exchangeData, ok := getExchangeDataFromIntent(intentRecord) @@ -78,26 +63,6 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notif return nil, errors.Wrap(err, "error creating proto chat message") } - v2Message := &chatv2pb.ChatMessage{ - MessageId: chat_v2.GenerateMessageId().ToProto(), - Content: []*chatv2pb.Content{ - { - Type: &chatv2pb.Content_ExchangeData{ - ExchangeData: &chatv2pb.ExchangeDataContent{ - Verb: chatv2pb.ExchangeDataContent_Verb(verb), - ExchangeData: &chatv2pb.ExchangeDataContent_Exact{ - Exact: exchangeData, - }, - Reference: &chatv2pb.ExchangeDataContent_Intent{ - Intent: &commonpb.IntentId{Value: intentIdRaw}, - }, - }, - }, - }, - }, - Ts: timestamppb.New(intentRecord.CreatedAt), - } - canPush, err := SendNotificationChatMessageV1( ctx, data, @@ -112,23 +77,6 @@ func SendTipsExchangeMessage(ctx context.Context, data code_data.Provider, notif return nil, errors.Wrap(err, "error persisting v1 chat message") } - _, err = SendNotificationChatMessageV2( - ctx, - data, - notifier, - TipsName, - true, - receiver, - v2Message, - intentRecord.IntentId, - verb != chatpb.ExchangeDataContent_RECEIVED_TIP, - ) - if err != nil { - // TODO: Eventually we'll want to return an error, but for now we'll log - // since we're not in 'prod' yet. - logrus.StandardLogger().WithError(err).Warn("Failed to send notification message (v2)") - } - if canPush { messagesToPush = append(messagesToPush, &MessageWithOwner{ Owner: receiver, diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index 82ac8755..ff54f9ac 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -100,12 +100,13 @@ func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[c } // GetUnreadCount implements chat.store.GetUnreadCount -func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, readPointer chat.MessageId) (uint32, error) { +func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) { s.mu.Lock() defer s.mu.Unlock() items := s.findMessagesByChatId(chatId) items = s.filterMessagesAfter(items, readPointer) + items = s.filterMessagesNotSentBy(items, memberId) items = s.filterNotifiedMessages(items) return uint32(len(items)), nil } @@ -446,6 +447,16 @@ func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.Me return res } +func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord { + var res []*chat.MessageRecord + for _, item := range items { + if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) { + res = append(res, item) + } + } + return res +} + func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { var res []*chat.MessageRecord for _, item := range items { diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 4aa3e1b7..2029f6db 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -4,6 +4,8 @@ import ( "context" "errors" + "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + "github.com/code-payments/code-server/pkg/database/query" ) @@ -43,7 +45,7 @@ type Store interface { GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) // GetUnreadCount gets the unread message count for a chat ID at a read pointer - GetUnreadCount(ctx context.Context, chatId ChatId, readPointer MessageId) (uint32, error) + GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error) // PutChat creates a new chat PutChat(ctx context.Context, record *ChatRecord) error @@ -66,3 +68,23 @@ type Store interface { // SetSubscriptionState updates the subscription state for a chat member SetSubscriptionState(ctx context.Context, chatId ChatId, memberId MemberId, isSubscribed bool) error } + +type PaymentStore interface { + // MarkFriendshipPaid marks a friendship as paid. + // + // The intentId is the intent that paid for the friendship. + MarkFriendshipPaid(ctx context.Context, payer, other *common.SolanaAccountId, intentId *common.IntentId) error + + // IsFriendshipPaid returns whether a payment has been made for a friendship. + // + // IsFriendshipPaid is reflexive, with only a single payment being required. + IsFriendshipPaid(ctx context.Context, user, other *common.SolanaAccountId) (bool, error) + + // MarkChatPaid marks a chat as paid. + MarkChatPaid(ctx context.Context, payer *common.SolanaAccountId, chat ChatId) error + + // IsChatPaid returns whether a member paid to be part of a chat. + // + // This is only valid for non-two way chats. + IsChatPaid(ctx context.Context, chatId ChatId, member *common.SolanaAccountId) (bool, error) +} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index a0fcc965..342bc3c6 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -71,7 +71,7 @@ import ( intent_memory_client "github.com/code-payments/code-server/pkg/code/data/intent/memory" login_memory_client "github.com/code-payments/code-server/pkg/code/data/login/memory" merkletree_memory_client "github.com/code-payments/code-server/pkg/code/data/merkletree/memory" - messaging "github.com/code-payments/code-server/pkg/code/data/messaging" + "github.com/code-payments/code-server/pkg/code/data/messaging" messaging_memory_client "github.com/code-payments/code-server/pkg/code/data/messaging/memory" nonce_memory_client "github.com/code-payments/code-server/pkg/code/data/nonce/memory" onramp_memory_client "github.com/code-payments/code-server/pkg/code/data/onramp/memory" @@ -402,7 +402,7 @@ type DatabaseData interface { GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) - GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error @@ -1493,8 +1493,8 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha } return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) } -func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) { - return dp.chatv2.GetUnreadCount(ctx, chatId, readPointer) +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer) } func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { return dp.chatv2.PutChat(ctx, record) diff --git a/pkg/code/server/grpc/chat/v1/server.go b/pkg/code/server/grpc/chat/v1/server.go index 49859416..5fef09f9 100644 --- a/pkg/code/server/grpc/chat/v1/server.go +++ b/pkg/code/server/grpc/chat/v1/server.go @@ -54,18 +54,18 @@ type server struct { chatLocks *sync_util.StripedLock chatEventChans *sync_util.StripedChannel - streamsMu sync.RWMutex - streams map[string]*chatEventStream - chatpb.UnimplementedChatServer } -func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier) chatpb.ChatServer { - return &server{ - log: logrus.StandardLogger().WithField("type", "chat/server"), - data: data, - auth: auth, - streams: make(map[string]*chatEventStream), +func NewChatServer(data code_data.Provider, auth *auth_util.RPCSignatureVerifier, pusher push_lib.Provider) chatpb.ChatServer { + s := &server{ + log: logrus.StandardLogger().WithField("type", "chat/v1/server"), + data: data, + auth: auth, + pusher: pusher, + streams: make(map[string]*chatEventStream), + chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters + chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters } for i, channel := range s.chatEventChans.GetChannels() { @@ -403,21 +403,9 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR Pointers: []*chatpb.Pointer{req.Pointer}, } - s.streamsMu.RLock() - for key, stream := range s.streams { - if !strings.HasPrefix(key, chatId.String()) { - continue - } - - if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { - continue - } - - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - } + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - s.streamsMu.RUnlock() return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_OK, @@ -428,7 +416,7 @@ func (s *server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR return nil, status.Error(codes.InvalidArgument, "Pointer.Kind must be READ") } - chatRecord, err := s.data.GetChatById(ctx, chatId) + chatRecord, err := s.data.GetChatByIdV1(ctx, chatId) if err == chat.ErrChatNotFound { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, @@ -656,6 +644,22 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e s.streamsMu.Unlock() + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another OpenMessageStream() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == stream { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + sendPingCh := time.After(0) streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) @@ -734,6 +738,10 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, status.Error(codes.InvalidArgument, "content[0] must be Text or ThankYou") } + chatLock := s.chatLocks.Get(chatId[:]) + chatLock.Lock() + defer chatLock.Unlock() + // todo: Revisit message IDs messageId, err := common.NewRandomAccount() if err != nil { @@ -749,28 +757,54 @@ func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest Cursor: nil, // todo: Don't have cursor until we save it to the DB } + // todo: Save the message to the DB + event := &chatpb.ChatStreamEvent{ Messages: []*chatpb.ChatMessage{chatMessage}, } - s.streamsMu.RLock() - for key, stream := range s.streams { - if !strings.HasPrefix(key, chatId.String()) { - continue - } - - if strings.HasSuffix(key, owner.PublicKey().ToBase58()) { - continue - } - - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - } + if err := s.asyncNotifyAll(chatId, owner, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - s.streamsMu.RUnlock() + + s.asyncPushChatMessage(owner, chatId, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, Message: chatMessage, }, nil } + +// todo: doesn't respect mute/unsubscribe rules +// todo: only sends pushes to active stream listeners instead of all message recipients +func (s *server) asyncPushChatMessage(sender *common.Account, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { + ctx := context.TODO() + + go func() { + s.streamsMu.RLock() + for key := range s.streams { + if !strings.HasPrefix(key, chatId.String()) { + continue + } + + receiver, err := common.NewAccountFromPublicKeyString(strings.Split(key, ":")[1]) + if err != nil { + continue + } + + if bytes.Equal(sender.PublicKey().ToBytes(), receiver.PublicKey().ToBytes()) { + continue + } + + go push_util.SendChatMessagePushNotification( + ctx, + s.data, + s.pusher, + "TontonTwitch", + receiver, + chatMessage, + ) + } + s.streamsMu.RUnlock() + }() +} diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 4b719698..33781317 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -2,7 +2,6 @@ package chat_v2 import ( "context" - "crypto/rand" "database/sql" "fmt" "math" @@ -22,8 +21,6 @@ import ( chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" - transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" - auth_util "github.com/code-payments/code-server/pkg/code/auth" chatv2 "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" @@ -548,160 +545,148 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* return nil, err } - switch typed := req.Parameters.(type) { - case *chatpb.StartChatRequest_TipChat: - intentId := base58.Encode(typed.TipChat.IntentId.Value) - log = log.WithField("intent", intentId) - - intentRecord, err := s.data.GetIntent(ctx, intentId) - if err == intent.ErrIntentNotFound { - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_INVALID_PARAMETER, - Chat: nil, - }, nil - } else if err != nil { - log.WithError(err).Warn("failure getting intent record") - return nil, status.Error(codes.Internal, "") - } - - // The intent was not for a tip. - if intentRecord.SendPrivatePaymentMetadata == nil || !intentRecord.SendPrivatePaymentMetadata.IsTip { - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_INVALID_PARAMETER, - Chat: nil, - }, nil - } - - tipper, err := common.NewAccountFromPublicKeyString(intentRecord.InitiatorOwnerAccount) - if err != nil { - log.WithError(err).Warn("invalid tipper owner account") - return nil, status.Error(codes.Internal, "") - } - log = log.WithField("tipper", tipper.PublicKey().ToBase58()) + // todo: Maybe expand this in the future. + if req.Self.Platform != chatpb.Platform_TWITTER { + log.Info("cannot start chat without specifying username") + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_INVALID_PARAMETER}, nil + } - tippee, err := common.NewAccountFromPublicKeyString(intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount) - if err != nil { - log.WithError(err).Warn("invalid tippee owner account") - return nil, status.Error(codes.Internal, "") - } - log = log.WithField("tippee", tippee.PublicKey().ToBase58()) - - // For now, don't allow chats where you tipped yourself. - // - // todo: How do we want to handle this case? - if owner.PublicKey().ToBase58() == tipper.PublicKey().ToBase58() { - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_INVALID_PARAMETER, - Chat: nil, - }, nil - } + selfVerified, err := s.ownsTwitterUsername(ctx, owner, req.Self.Username) + if err != nil { + log.WithError(err).Warn("failed to verify creators twitter") + return nil, status.Error(codes.Internal, "") + } + if !selfVerified { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } - // Only the owner of the platform user at the time of tipping can initiate the chat. - if owner.PublicKey().ToBase58() != tippee.PublicKey().ToBase58() { - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_DENIED, - Chat: nil, - }, nil - } + switch typed := req.Parameters.(type) { + case *chatpb.StartChatRequest_TwoWayChat: + chatId := chat.GetChatId(owner.PublicKey().ToBase58(), base58.Encode(typed.TwoWayChat.OtherUser.Value), true) - // todo: This will require a refactor when we allow creation of other types of chats - switch intentRecord.SendPrivatePaymentMetadata.TipMetadata.Platform { - case transactionpb.TippedUser_TWITTER: - twitterUsername := intentRecord.SendPrivatePaymentMetadata.TipMetadata.Username + if typed.TwoWayChat.IntentId == nil { + /* + isFriends, err := s.data.IsFriendshipPaid(ctx, owner, typed.TwoWayChat.OtherUser) + if err != nil { + log.WithError(err).Warn("failure checking two way chat") + return nil, status.Error(codes.Internal, "") + } - // The owner must still own the Twitter username - ownsUsername, err := s.ownsTwitterUsername(ctx, owner, twitterUsername) - if err != nil { - log.WithError(err).Warn("failure determing twitter username ownership") + if !isFriends { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + */ + _, err = s.data.GetChatByIdV2(ctx, chatId) + if errors.Is(err, chat.ErrChatNotFound) { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } else if err != nil { + log.WithError(err).Warn("failure checking two way chat") return nil, status.Error(codes.Internal, "") - } else if !ownsUsername { + } + } else { + intentId := base58.Encode(typed.TwoWayChat.IntentId.Value) + log = log.WithField("intent", intentId) + + intentRecord, err := s.data.GetIntent(ctx, intentId) + if errors.Is(err, intent.ErrIntentNotFound) { + log.WithError(err).Info("Intent not found") return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_DENIED, + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting intent record") + return nil, status.Error(codes.Internal, "") } - // todo: try to find an existing chat, but for now always create a new completely random one - var chatId chat.ChatId - rand.Read(chatId[:]) + if intentRecord.SendPrivatePaymentMetadata == nil { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } - creationTs := time.Now() + // TODO: Further verification + } - chatRecord := &chat.ChatRecord{ + // At this point, we assume the relationship is valid, and can proceed to recover or create + // the chat record. + creationTs := time.Now() + chatRecord := &chat.ChatRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + IsVerified: true, + CreatedAt: creationTs, + } + memberRecords := []*chat.MemberRecord{ + { ChatId: chatId, - ChatType: chat.ChatTypeTwoWay, - - IsVerified: true, + MemberId: chat.GenerateMemberId(), - CreatedAt: creationTs, - } - - memberRecords := []*chat.MemberRecord{ - { - ChatId: chatId, - MemberId: chat.GenerateMemberId(), + Platform: chat.PlatformTwitter, + PlatformId: req.Self.Username, + OwnerAccount: owner.PublicKey().ToBase58(), - Platform: chat.PlatformTwitter, - PlatformId: twitterUsername, - OwnerAccount: owner.PublicKey().ToBase58(), + JoinedAt: creationTs, + }, + { + ChatId: chatId, + MemberId: chat.GenerateMemberId(), - JoinedAt: creationTs, - }, - { - ChatId: chatId, - MemberId: chat.GenerateMemberId(), + Platform: chat.PlatformTwitter, + PlatformId: typed.TwoWayChat.Identity.Username, - Platform: chat.PlatformCode, - PlatformId: tipper.PublicKey().ToBase58(), + JoinedAt: time.Now(), + }, + } - JoinedAt: creationTs, - }, + err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { + chatRecord, err = s.data.GetChatByIdV2(ctx, chatId) + if err != nil && errors.Is(err, chat.ErrChatNotFound) { + return fmt.Errorf("failed to check existing chat: %w", err) } - err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { - err := s.data.PutChatV2(ctx, chatRecord) + if chatRecord != nil { + memberRecords, err = s.data.GetAllChatMembersV2(ctx, chatId) if err != nil { - return errors.Wrap(err, "error creating chat record") - } - - for _, memberRecord := range memberRecords { - err := s.data.PutChatMemberV2(ctx, memberRecord) - if err != nil { - return errors.Wrap(err, "error creating member record") - } + return fmt.Errorf("failed to get members of existing chat: %w", err) } return nil - }) - if err != nil { - log.WithError(err).Warn("failure creating new chat") - return nil, status.Error(codes.Internal, "") } - protoChat, err := s.toProtoChat( - ctx, - chatRecord, - memberRecords, - map[chat.Platform]string{ - chat.PlatformCode: owner.PublicKey().ToBase58(), - chat.PlatformTwitter: twitterUsername, - }, - ) - if err != nil { - log.WithError(err).Warn("failure constructing proto chat message") - return nil, status.Error(codes.Internal, "") + if err = s.data.PutChatV2(ctx, chatRecord); err != nil { + return fmt.Errorf("failed to save new chat: %w", err) + } + for _, m := range memberRecords { + if err = s.data.PutChatMemberV2(ctx, m); err != nil { + return fmt.Errorf("failed to add member to chat: %w", err) + } } - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_OK, - Chat: protoChat, - }, nil - default: - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_INVALID_PARAMETER, - Chat: nil, - }, nil + return nil + }) + if err != nil { + log.WithError(err).Warn("failure creating chat") + return nil, status.Error(codes.Internal, "") } + protoChat, err := s.toProtoChat( + ctx, + chatRecord, + memberRecords, + map[chat.Platform]string{ + chat.PlatformCode: owner.PublicKey().ToBase58(), + chat.PlatformTwitter: req.Self.Username, + }, + ) + if err != nil { + log.WithError(err).Warn("failure constructing proto chat message") + return nil, status.Error(codes.Internal, "") + } + + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_OK, + Chat: protoChat, + }, nil + default: return nil, status.Error(codes.InvalidArgument, "StartChatRequest.Parameters is nil") } diff --git a/pkg/code/server/grpc/transaction/v2/history_test.go b/pkg/code/server/grpc/transaction/v2/history_test.go index 3ad912b7..80ae0442 100644 --- a/pkg/code/server/grpc/transaction/v2/history_test.go +++ b/pkg/code/server/grpc/transaction/v2/history_test.go @@ -3,7 +3,6 @@ package transaction_v2 import ( "testing" - "github.com/mr-tron/base58" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -14,7 +13,6 @@ import ( chat_util "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" - chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" currency_lib "github.com/code-payments/code-server/pkg/currency" "github.com/code-payments/code-server/pkg/kin" timelock_token_v1 "github.com/code-payments/code-server/pkg/solana/timelock/v1" @@ -340,10 +338,6 @@ func TestPaymentHistory_HappyPath(t *testing.T) { require.NoError(t, err) require.Len(t, chatMessageRecords, 1) - chatMessageRecordsV2, err := server.data.GetAllChatMessagesV2(server.ctx, chat_v2.GetChatId(chat_util.TipsName, sendingPhone.parentAccount.PublicKey().ToBase58(), true)) - require.NoError(t, err) - requireEquivalent(t, chatMessageRecords, chatMessageRecordsV2) - protoChatMessage = getProtoChatMessage(t, chatMessageRecords[0]) require.Len(t, protoChatMessage.Content, 1) require.NotNil(t, protoChatMessage.Content[0].GetExchangeData()) @@ -419,10 +413,6 @@ func TestPaymentHistory_HappyPath(t *testing.T) { require.NoError(t, err) require.Len(t, chatMessageRecords, 1) - chatMessageRecordsV2, err = server.data.GetAllChatMessagesV2(server.ctx, chat_v2.GetChatId(chat_util.TipsName, receivingPhone.parentAccount.PublicKey().ToBase58(), true)) - require.NoError(t, err) - requireEquivalent(t, chatMessageRecords, chatMessageRecordsV2) - protoChatMessage = getProtoChatMessage(t, chatMessageRecords[0]) require.Len(t, protoChatMessage.Content, 1) require.NotNil(t, protoChatMessage.Content[0].GetExchangeData()) @@ -432,49 +422,3 @@ func TestPaymentHistory_HappyPath(t *testing.T) { assert.Equal(t, 45.6, protoChatMessage.Content[0].GetExchangeData().GetExact().NativeAmount) assert.Equal(t, kin.ToQuarks(456), protoChatMessage.Content[0].GetExchangeData().GetExact().Quarks) } - -func requireEquivalent(t *testing.T, v1 []*chat_v1.Message, v2 []*chat_v2.MessageRecord) { - require.Equal(t, len(v1), len(v2)) - - for i, v1Record := range v1 { - v2Record := v2[i] - - require.Equal(t, v1Record.ChatId[:], v2Record.ChatId[:]) - require.Equal(t, v1Record.IsSilent, v2Record.IsSilent) - - v1Message := getProtoChatMessage(t, v1Record) - require.Empty(t, v1Message.Sender) - - v2Message := getProtoChatMessageV2(t, v2Record) - require.Empty(t, v2Message.SenderId) - - require.Equal(t, len(v1Message.Content), len(v2Message.Content)) - - // TODO: Move this to somewhere common? - for c := range v1Message.Content { - a := v1Message.Content[c].GetExchangeData() - require.NotNil(t, a) - - b := v2Message.Content[c].GetExchangeData() - require.NotNil(t, b) - - require.EqualValues(t, a.Verb, b.Verb) - - if a.GetExact() != nil { - require.Equal(t, a.GetExact().Currency, b.GetExact().Currency) - require.Equal(t, a.GetExact().ExchangeRate, b.GetExact().ExchangeRate) - require.Equal(t, a.GetExact().NativeAmount, b.GetExact().NativeAmount) - require.Equal(t, a.GetExact().Quarks, b.GetExact().Quarks) - } else if a.GetPartial() != nil { - require.Equal(t, a.GetPartial().Currency, b.GetPartial().Currency) - require.Equal(t, a.GetPartial().NativeAmount, b.GetPartial().NativeAmount) - } else { - t.Fatal("Unhandled case") - } - - intent := b.GetIntent() - require.NotNil(t, intent) - require.Equal(t, v1Record.MessageId, base58.Encode(intent.Value)) - } - } -} From e4c78602d89dbc38732174007652f073f5d33b8a Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Tue, 17 Sep 2024 15:31:34 -0400 Subject: [PATCH 63/71] chat,user: pipe back FriendChatId in user, debug checks on StartChat --- go.mod | 2 ++ go.sum | 4 ++-- pkg/code/server/grpc/chat/v2/server.go | 13 +++++++++++-- pkg/code/server/grpc/user/server.go | 13 ++++++++++++- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index cc45d8ea..8b223a97 100644 --- a/go.mod +++ b/go.mod @@ -131,3 +131,5 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108 diff --git a/go.sum b/go.sum index 7414b428..2518c2e6 100644 --- a/go.sum +++ b/go.sum @@ -121,8 +121,6 @@ github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWH github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= -github.com/code-payments/code-protobuf-api v1.19.0 h1:md/eJhqltz8dDY0U8hwT/42C3h+kP+W/68D7RMSjqPo= -github.com/code-payments/code-protobuf-api v1.19.0/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6 h1:NmTXa/uVnDyp0TY5MKi197+3HWcnYWfnHGyaFthlnGw= github.com/containerd/continuity v0.0.0-20190827140505-75bee3e2ccb6/go.mod h1:GL3xCUCBDV3CZiTSEKksMWbLE66hEyuu9qyDOOqM47Y= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= @@ -425,6 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108 h1:pFXpZNkx7E1MfUyf3G6kD95L0aPry98WunNH6nszHWY= +github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 33781317..2e50e59e 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -3,6 +3,7 @@ package chat_v2 import ( "context" "database/sql" + "encoding/base64" "fmt" "math" "sync" @@ -603,7 +604,15 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil } - // TODO: Further verification + // TODO: Further verification + Enforcement + if !intentRecord.SendPrivatePaymentMetadata.IsChat { + log.Warn("intent is not for chat") + } + + expectedChatId := base64.StdEncoding.EncodeToString(chatId[:]) + if intentRecord.SendPrivatePaymentMetadata.ChatId != expectedChatId { + log.WithField("expected", expectedChatId).WithField("actual", intentRecord.SendPrivatePaymentMetadata.ChatId).Warn("chat id mismatch") + } } // At this point, we assume the relationship is valid, and can proceed to recover or create @@ -639,7 +648,7 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { chatRecord, err = s.data.GetChatByIdV2(ctx, chatId) - if err != nil && errors.Is(err, chat.ErrChatNotFound) { + if err != nil && !errors.Is(err, chat.ErrChatNotFound) { return fmt.Errorf("failed to check existing chat: %w", err) } diff --git a/pkg/code/server/grpc/user/server.go b/pkg/code/server/grpc/user/server.go index 1b24adf6..c4f536cf 100644 --- a/pkg/code/server/grpc/user/server.go +++ b/pkg/code/server/grpc/user/server.go @@ -22,6 +22,7 @@ import ( "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" "github.com/code-payments/code-server/pkg/code/data/account" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/intent" "github.com/code-payments/code-server/pkg/code/data/paymentrequest" "github.com/code-payments/code-server/pkg/code/data/phone" @@ -701,6 +702,12 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit return nil, status.Error(codes.Internal, "") } + var friendChatId *commonpb.ChatId + if req.Requestor != nil { + // TODO: Validate the requestor + friendChatId = chat.GetChatId(base58.Encode(req.Requestor.Value), tipAddress.PublicKey().ToBase58(), true).ToProto() + } + return &userpb.GetTwitterUserResponse{ Result: userpb.GetTwitterUserResponse_OK, TwitterUser: &userpb.TwitterUser{ @@ -710,6 +717,11 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit ProfilePicUrl: record.ProfilePicUrl, VerifiedType: record.VerifiedType, FollowerCount: record.FollowerCount, + FriendshipCost: &transactionpb.ExchangeDataWithoutRate{ + Currency: "usd", + NativeAmount: 1.0, + }, + FriendChatId: friendChatId, }, }, nil case twitter.ErrUserNotFound: @@ -720,7 +732,6 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit log.WithError(err).Warn("failure getting twitter user info") return nil, status.Error(codes.Internal, "") } - } func (s *identityServer) markWebhookAsPending(ctx context.Context, id string) error { From 21d87844d837d00cb8648ec097a37c4eabe315b3 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Tue, 17 Sep 2024 15:44:55 -0400 Subject: [PATCH 64/71] chat: fix build --- pkg/code/data/chat/v2/id.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index 5f5d54cb..e3a17066 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -12,6 +12,7 @@ import ( "github.com/pkg/errors" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" ) type ChatId [32]byte @@ -51,7 +52,7 @@ func GetChatIdFromString(value string) (ChatId, error) { } // GetChatIdFromProto gets a chat ID from the protobuf variant -func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { +func GetChatIdFromProto(proto *commonpb.ChatId) (ChatId, error) { if err := proto.Validate(); err != nil { return ChatId{}, errors.Wrap(err, "proto validation failed") } @@ -60,8 +61,8 @@ func GetChatIdFromProto(proto *chatpb.ChatId) (ChatId, error) { } // ToProto converts a chat ID to its protobuf variant -func (c ChatId) ToProto() *chatpb.ChatId { - return &chatpb.ChatId{Value: c[:]} +func (c ChatId) ToProto() *commonpb.ChatId { + return &commonpb.ChatId{Value: c[:]} } // Validate validates a chat ID From 2aa7ad8142347d1c9f03b31abb2317230051dab7 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 18 Sep 2024 10:16:24 -0400 Subject: [PATCH 65/71] chat: don't overwrite initial ChatRecord on dupe check. --- go.mod | 2 +- go.sum | 4 ++-- pkg/code/server/grpc/chat/v2/server.go | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index 8b223a97..0fefd303 100644 --- a/go.mod +++ b/go.mod @@ -132,4 +132,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108 +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c diff --git a/go.sum b/go.sum index 2518c2e6..80e0bd0c 100644 --- a/go.sum +++ b/go.sum @@ -423,8 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108 h1:pFXpZNkx7E1MfUyf3G6kD95L0aPry98WunNH6nszHWY= -github.com/mfycheng/code-protobuf-api v0.0.0-20240917192150-82a26e4a0108/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c h1:xGrUQGe2WTFEQ4+S0Fy1JYf7m/VoiNp6Ubckod7MU2I= +github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 2e50e59e..445c9cc9 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -647,12 +647,13 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* } err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { - chatRecord, err = s.data.GetChatByIdV2(ctx, chatId) + existingChatRecord, err := s.data.GetChatByIdV2(ctx, chatId) if err != nil && !errors.Is(err, chat.ErrChatNotFound) { return fmt.Errorf("failed to check existing chat: %w", err) } if chatRecord != nil { + chatRecord = existingChatRecord memberRecords, err = s.data.GetAllChatMembersV2(ctx, chatId) if err != nil { return fmt.Errorf("failed to get members of existing chat: %w", err) From 0ef896a3b85832ad920992675d3f41c26cdaf5cc Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 18 Sep 2024 10:22:14 -0400 Subject: [PATCH 66/71] chat: fix existing chat condition --- pkg/code/server/grpc/chat/v2/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 445c9cc9..f3e8737d 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -652,7 +652,7 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* return fmt.Errorf("failed to check existing chat: %w", err) } - if chatRecord != nil { + if existingChatRecord != nil { chatRecord = existingChatRecord memberRecords, err = s.data.GetAllChatMembersV2(ctx, chatId) if err != nil { From 726767259770e1c51b5e5f5c0901cc02c49f5239 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 18 Sep 2024 10:37:54 -0400 Subject: [PATCH 67/71] chat: hack to bypass member validation --- pkg/code/server/grpc/chat/v2/server.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index f3e8737d..ba4551a4 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -639,8 +639,9 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* ChatId: chatId, MemberId: chat.GenerateMemberId(), - Platform: chat.PlatformTwitter, - PlatformId: typed.TwoWayChat.Identity.Username, + Platform: chat.PlatformTwitter, + PlatformId: typed.TwoWayChat.Identity.Username, + OwnerAccount: base58.Encode(typed.TwoWayChat.OtherUser.Value), JoinedAt: time.Now(), }, From 1398a235bcc8c7d10d7e7ccbdca3ac9cde8a22dc Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Wed, 18 Sep 2024 11:34:48 -0400 Subject: [PATCH 68/71] chat: return profilepic urls --- go.mod | 2 +- go.sum | 4 ++-- pkg/code/server/grpc/chat/v2/server.go | 13 +++++++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index 0fefd303..82c08e99 100644 --- a/go.mod +++ b/go.mod @@ -132,4 +132,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f diff --git a/go.sum b/go.sum index 80e0bd0c..b6327cdf 100644 --- a/go.sum +++ b/go.sum @@ -423,8 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c h1:xGrUQGe2WTFEQ4+S0Fy1JYf7m/VoiNp6Ubckod7MU2I= -github.com/mfycheng/code-protobuf-api v0.0.0-20240918135149-d5567bcb2c3c/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f h1:x7RVQ91vIh2EpMmCtaTMnKITD7v0Rj4vErH5fQb4trQ= +github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index ba4551a4..cbdc2002 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1281,9 +1281,18 @@ func (s *Server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] isSelf = ok && myTwitterUsername == memberRecord.PlatformId + profilePicUrl := "" + user, err := s.data.GetTwitterUserByUsername(ctx, memberRecord.PlatformId) + if err != nil { + s.log.WithError(err).WithField("method", "toProtoChat").Warn("Failed to get twitter user for member record") + } else { + profilePicUrl = user.ProfilePicUrl + } + identity = &chatpb.ChatMemberIdentity{ - Platform: memberRecord.Platform.ToProto(), - Username: memberRecord.PlatformId, + Platform: memberRecord.Platform.ToProto(), + Username: memberRecord.PlatformId, + ProfilePicUrl: profilePicUrl, } default: return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) From cf233cb02fb8deae8096707bb528d3d77f2f2dff Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Mon, 30 Sep 2024 12:21:17 -0400 Subject: [PATCH 69/71] chat overhaul, needs a new branch --- go.mod | 2 +- go.sum | 4 +- pkg/code/chat/notifier.go | 4 +- pkg/code/chat/sender.go | 141 -- pkg/code/common/account.go | 22 + pkg/code/data/chat/v2/id.go | 74 +- pkg/code/data/chat/v2/id_test.go | 8 - pkg/code/data/chat/v2/memory/store.go | 675 ++++------ pkg/code/data/chat/v2/memory/store_test.go | 532 +++++++- pkg/code/data/chat/v2/model.go | 330 ++--- pkg/code/data/chat/v2/store.go | 104 +- pkg/code/data/chat/v2/tests/tests.go | 14 - pkg/code/data/internal.go | 76 +- pkg/code/push/notifications.go | 11 +- pkg/code/server/grpc/chat/v2/server.go | 1173 ++++++----------- pkg/code/server/grpc/chat/v2/server_test.go | 375 ++++++ .../grpc/transaction/v2/intent_handler.go | 9 + .../server/grpc/transaction/v2/testutil.go | 8 - pkg/code/server/grpc/user/server.go | 14 +- pkg/database/query/cursor.go | 2 +- pkg/pointer/pointer.go | 9 + pkg/testutil/proto.go | 33 + 22 files changed, 1896 insertions(+), 1724 deletions(-) delete mode 100644 pkg/code/data/chat/v2/tests/tests.go create mode 100644 pkg/testutil/proto.go diff --git a/go.mod b/go.mod index 82c08e99..2463acd3 100644 --- a/go.mod +++ b/go.mod @@ -132,4 +132,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8 diff --git a/go.sum b/go.sum index b6327cdf..16246996 100644 --- a/go.sum +++ b/go.sum @@ -423,8 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f h1:x7RVQ91vIh2EpMmCtaTMnKITD7v0Rj4vErH5fQb4trQ= -github.com/mfycheng/code-protobuf-api v0.0.0-20240918153306-b04180279f5f/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8 h1:cP0i0oAMtWtyBP0wMOuVOzg2i3dYQZOuq2CtXrgr8iM= +github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/chat/notifier.go b/pkg/code/chat/notifier.go index 15839aa2..7fa4c1ff 100644 --- a/pkg/code/chat/notifier.go +++ b/pkg/code/chat/notifier.go @@ -9,7 +9,7 @@ import ( ) type Notifier interface { - NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) + NotifyMessage(ctx context.Context, chatID chat.ChatId, message *chatpb.Message) } type NoopNotifier struct{} @@ -18,5 +18,5 @@ func NewNoopNotifier() *NoopNotifier { return &NoopNotifier{} } -func (n *NoopNotifier) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.ChatMessage) { +func (n *NoopNotifier) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.Message) { } diff --git a/pkg/code/chat/sender.go b/pkg/code/chat/sender.go index 4b56f31c..f63436a8 100644 --- a/pkg/code/chat/sender.go +++ b/pkg/code/chat/sender.go @@ -3,20 +3,15 @@ package chat import ( "context" "errors" - "fmt" "time" "github.com/mr-tron/base58" - "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" - chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" - "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" - chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" ) // SendNotificationChatMessageV1 sends a chat message to a receiving owner account. @@ -119,139 +114,3 @@ func SendNotificationChatMessageV1( return canPushMessage, nil } - -func SendNotificationChatMessageV2( - ctx context.Context, - data code_data.Provider, - notifier Notifier, - chatTitle string, - isVerifiedChat bool, - receiver *common.Account, - protoMessage *chatv2pb.ChatMessage, - intentId string, - isSilentMessage bool, -) (canPushMessage bool, err error) { - log := logrus.StandardLogger().WithField("type", "sendNotificationChatMessageV2") - - chatId := chat_v2.GetChatId(chatTitle, receiver.PublicKey().ToBase58(), isVerifiedChat) - - if protoMessage.Cursor != nil { - // Let the utilities and GetMessages RPC handle cursors - return false, errors.New("cursor must not be set") - } - - if err := protoMessage.Validate(); err != nil { - return false, err - } - - messageId, err := chat_v2.GetMessageIdFromProto(protoMessage.MessageId) - if err != nil { - return false, fmt.Errorf("invalid message id: %w", err) - } - - // Clear out extracted metadata as a space optimization - cloned := proto.Clone(protoMessage).(*chatv2pb.ChatMessage) - cloned.MessageId = nil - cloned.Ts = nil - cloned.Cursor = nil - - marshalled, err := proto.Marshal(cloned) - if err != nil { - return false, err - } - - canPersistMessage := true - canPushMessage = !isSilentMessage - - // - // Step 1: Check to see if we need to create the chat. - // - _, err = data.GetChatByIdV2(ctx, chatId) - if errors.Is(err, chat_v2.ErrChatNotFound) { - chatRecord := &chat_v2.ChatRecord{ - ChatId: chatId, - ChatType: chat_v2.ChatTypeNotification, - ChatTitle: &chatTitle, - IsVerified: isVerifiedChat, - - CreatedAt: time.Now(), - } - - // TODO: These should be run in a transaction, but so far - // we're being called in a transaction. We should have some kind - // of safety check here... - err = data.PutChatV2(ctx, chatRecord) - if err != nil && !errors.Is(err, chat_v2.ErrChatExists) { - return false, fmt.Errorf("failed to initialize chat: %w", err) - } - - memberId := chat_v2.GenerateMemberId() - err = data.PutChatMemberV2(ctx, &chat_v2.MemberRecord{ - ChatId: chatId, - MemberId: memberId, - Platform: chat_v2.PlatformCode, - PlatformId: receiver.PublicKey().ToBase58(), - JoinedAt: time.Now(), - }) - if err != nil { - return false, fmt.Errorf("failed to initialize chat with member: %w", err) - } - - log.WithFields(logrus.Fields{ - "chat_id": chatId.String(), - "member": memberId.String(), - "platform_id": receiver.PublicKey().ToBase58(), - }).Info("Initialized chat for tip") - - } else if err != nil { - return false, err - } - - // - // Step 2: Ensure that there is exactly 1 member in the chat. - // - members, err := data.GetAllChatMembersV2(ctx, chatId) - if errors.Is(err, chat_v2.ErrMemberNotFound) { // TODO: This is a weird error... - return false, nil - } else if err != nil { - return false, err - } - if len(members) > 1 { - // TODO: This _could_ get weird if client or someone else decides to join as another member. - return false, errors.New("notification chat should have at most 1 member") - } - - canPersistMessage = !members[0].IsUnsubscribed - canPushMessage = canPushMessage && canPersistMessage && !members[0].IsMuted - - if canPersistMessage { - refType := chat_v2.ReferenceTypeIntent - messageRecord := &chat_v2.MessageRecord{ - ChatId: chatId, - MessageId: messageId, - - Data: marshalled, - IsSilent: isSilentMessage, - - ReferenceType: &refType, - Reference: &intentId, - } - - // TODO: Once we have a better idea on the data modeling around chatv2, - // we may wish to have the server manage the creation of messages - // (and chats?) as well. That would also put the - err = data.PutChatMessageV2(ctx, messageRecord) - if err != nil { - return false, err - } - - notifier.NotifyMessage(ctx, chatId, protoMessage) - - log.WithField("chat_id", chatId.String()).Info("Put and notified") - } - - // TODO: Once we move more things over to chatv2, we will need to increment - // badge count here. We don't currently, as it would result in a double - // push. - return canPushMessage, nil -} diff --git a/pkg/code/common/account.go b/pkg/code/common/account.go index d48d971d..417e3f4a 100644 --- a/pkg/code/common/account.go +++ b/pkg/code/common/account.go @@ -5,6 +5,7 @@ import ( "context" "crypto/ed25519" "fmt" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/pkg/errors" "github.com/sirupsen/logrus" @@ -196,6 +197,27 @@ func (a *Account) ToTimelockVault(dataVersion timelock_token_v1.TimelockDataVers return timelockAccounts.Vault, nil } +func (a *Account) ToMessagingAccount(mint *Account) (*Account, error) { + return a.ToTimelockVault(timelock_token_v1.DataVersion1, mint) +} + +func (a *Account) ToChatMemberId() (chat.MemberId, error) { + messagingAccount, err := a.ToMessagingAccount(KinMintAccount) + if err != nil { + return nil, err + } + + return messagingAccount.PublicKey().ToBytes(), nil +} + +func (a *Account) MustToChatMemberId() chat.MemberId { + id, err := a.ToChatMemberId() + if err != nil { + panic(err) + } + return id +} + func (a *Account) ToAssociatedTokenAccount(mint *Account) (*Account, error) { if err := a.Validate(); err != nil { return nil, errors.Wrap(err, "error validating owner account") diff --git a/pkg/code/data/chat/v2/id.go b/pkg/code/data/chat/v2/id.go index e3a17066..2aeed440 100644 --- a/pkg/code/data/chat/v2/id.go +++ b/pkg/code/data/chat/v2/id.go @@ -5,10 +5,10 @@ import ( "crypto/sha256" "encoding/hex" "fmt" - "strings" "time" "github.com/google/uuid" + "github.com/mr-tron/base58" "github.com/pkg/errors" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" @@ -33,22 +33,20 @@ func GetChatIdFromBytes(buffer []byte) (ChatId, error) { return typed, nil } -func GetChatId(sender, receiver string, isVerified bool) ChatId { - combined := []byte(fmt.Sprintf("%s:%s:%v", sender, receiver, isVerified)) - if strings.Compare(sender, receiver) > 0 { - combined = []byte(fmt.Sprintf("%s:%s:%v", receiver, sender, isVerified)) +// GetTwoWayChatId returns the ChatId for two users. +func GetTwoWayChatId(sender, receiver []byte) ChatId { + var a, b []byte + if bytes.Compare(sender, receiver) <= 0 { + a, b = sender, receiver + } else { + a, b = receiver, sender } - return sha256.Sum256(combined) -} -// GetChatIdFromBytes gets a chat ID from the string representation -func GetChatIdFromString(value string) (ChatId, error) { - decoded, err := hex.DecodeString(value) - if err != nil { - return ChatId{}, errors.Wrap(err, "value is not a hexadecimal string") - } + combined := make([]byte, len(a)+len(b)) + copy(combined, a) + copy(combined[len(a):], b) - return GetChatIdFromBytes(decoded) + return sha256.Sum256(combined) } // GetChatIdFromProto gets a chat ID from the protobuf variant @@ -82,21 +80,15 @@ func (c ChatId) String() string { return hex.EncodeToString(c[:]) } -// Random UUIDv4 ID for chat members -type MemberId uuid.UUID - -// GenerateMemberId generates a new random chat member ID -func GenerateMemberId() MemberId { - return MemberId(uuid.New()) -} +type MemberId []byte // GetMemberIdFromBytes gets a member ID from a byte buffer func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { - if len(buffer) != 16 { - return MemberId{}, errors.New("member id must be 16 bytes in length") + if len(buffer) != 32 { + return MemberId{}, errors.New("member id must be 32 bytes in length") } - var typed MemberId + typed := make(MemberId, len(buffer)) copy(typed[:], buffer[:]) if err := typed.Validate(); err != nil { @@ -108,16 +100,16 @@ func GetMemberIdFromBytes(buffer []byte) (MemberId, error) { // GetMemberIdFromString gets a chat member ID from the string representation func GetMemberIdFromString(value string) (MemberId, error) { - decoded, err := uuid.Parse(value) + b, err := base58.Decode(value) if err != nil { - return MemberId{}, errors.Wrap(err, "value is not a uuid string") + return MemberId{}, errors.Wrap(err, "invalid member id") } - return GetMemberIdFromBytes(decoded[:]) + return GetMemberIdFromBytes(b) } // GetMemberIdFromProto gets a member ID from the protobuf variant -func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { +func GetMemberIdFromProto(proto *chatpb.MemberId) (MemberId, error) { if err := proto.Validate(); err != nil { return MemberId{}, errors.Wrap(err, "proto validation failed") } @@ -126,16 +118,14 @@ func GetMemberIdFromProto(proto *chatpb.ChatMemberId) (MemberId, error) { } // ToProto converts a message ID to its protobuf variant -func (m MemberId) ToProto() *chatpb.ChatMemberId { - return &chatpb.ChatMemberId{Value: m[:]} +func (m MemberId) ToProto() *chatpb.MemberId { + return &chatpb.MemberId{Value: m[:]} } // Validate validates a chat member ID func (m MemberId) Validate() error { - casted := uuid.UUID(m) - - if casted.Version() != 4 { - return errors.Errorf("invalid uuid version: %s", casted.Version().String()) + if l := len(m); l < 0 || l > 32 { + return fmt.Errorf("member id must be in range 0-32, got: %d", l) } return nil @@ -143,17 +133,17 @@ func (m MemberId) Validate() error { // Clone clones a chat member ID func (m MemberId) Clone() MemberId { - var cloned MemberId + cloned := make(MemberId, len(m)) copy(cloned[:], m[:]) return cloned } // String returns the string representation of a MemberId func (m MemberId) String() string { - return uuid.UUID(m).String() + return base58.Encode(m[:]) } -// Time-based UUIDv7 ID for chat messages +// MessageId is a time-based UUIDv7 ID for chat messages type MessageId uuid.UUID // GenerateMessageId generates a UUIDv7 message ID using the current time @@ -184,7 +174,7 @@ func GenerateMessageIdAtTime(ts time.Time) MessageId { randomUUID := uuid.New() copy(uuidBytes[7:], randomUUID[7:]) - return MessageId(uuidBytes) + return uuidBytes } // GetMessageIdFromBytes gets a message ID from a byte buffer @@ -214,7 +204,7 @@ func GetMessageIdFromString(value string) (MessageId, error) { } // GetMessageIdFromProto gets a message ID from the protobuf variant -func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { +func GetMessageIdFromProto(proto *chatpb.MessageId) (MessageId, error) { if err := proto.Validate(); err != nil { return MessageId{}, errors.Wrap(err, "proto validation failed") } @@ -223,8 +213,8 @@ func GetMessageIdFromProto(proto *chatpb.ChatMessageId) (MessageId, error) { } // ToProto converts a message ID to its protobuf variant -func (m MessageId) ToProto() *chatpb.ChatMessageId { - return &chatpb.ChatMessageId{Value: m[:]} +func (m MessageId) ToProto() *chatpb.MessageId { + return &chatpb.MessageId{Value: m[:]} } // GetTimestamp gets the encoded timestamp in the message ID @@ -253,7 +243,7 @@ func (m MessageId) Before(other MessageId) bool { return m.Compare(other) < 0 } -// Before returns whether the message ID is after the provided value +// After returns whether the message ID is after the provided value func (m MessageId) After(other MessageId) bool { return m.Compare(other) > 0 } diff --git a/pkg/code/data/chat/v2/id_test.go b/pkg/code/data/chat/v2/id_test.go index 27f1f359..3a9b710c 100644 --- a/pkg/code/data/chat/v2/id_test.go +++ b/pkg/code/data/chat/v2/id_test.go @@ -9,14 +9,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestGenerateMemberId_Validation(t *testing.T) { - valid := GenerateMemberId() - assert.NoError(t, valid.Validate()) - - invalid := MemberId(GenerateMessageId()) - assert.Error(t, invalid.Validate()) -} - func TestGenerateMessageId_Validation(t *testing.T) { valid := GenerateMessageId() assert.NoError(t, valid.Validate()) diff --git a/pkg/code/data/chat/v2/memory/store.go b/pkg/code/data/chat/v2/memory/store.go index ff54f9ac..164cdc9c 100644 --- a/pkg/code/data/chat/v2/memory/store.go +++ b/pkg/code/data/chat/v2/memory/store.go @@ -3,539 +3,416 @@ package memory import ( "bytes" "context" + "slices" "sort" + "strings" "sync" - "time" - - "github.com/pkg/errors" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/database/query" ) -// todo: finish implementing me -type store struct { - mu sync.Mutex - - chatRecords []*chat.ChatRecord - memberRecords []*chat.MemberRecord - messageRecords []*chat.MessageRecord - - lastChatId int64 - lastMemberId int64 - lastMessageId int64 +type InMemoryStore struct { + mu sync.RWMutex + chats map[string]*chat.MetadataRecord + members map[string]map[string]*chat.MemberRecord + messages map[string][]*chat.MessageRecord } -// New returns a new in memory chat.Store -func New() chat.Store { - return &store{} -} - -// GetChatById implements chat.Store.GetChatById -func (s *store) GetChatById(_ context.Context, chatId chat.ChatId) (*chat.ChatRecord, error) { - s.mu.Lock() - defer s.mu.Unlock() - - item := s.findChatById(chatId) - if item == nil { - return nil, chat.ErrChatNotFound +func New() *InMemoryStore { + return &InMemoryStore{ + chats: make(map[string]*chat.MetadataRecord), + members: make(map[string]map[string]*chat.MemberRecord), + messages: make(map[string][]*chat.MessageRecord), } - - cloned := item.Clone() - return &cloned, nil } -// GetMemberById implements chat.Store.GetMemberById -func (s *store) GetMemberById(_ context.Context, chatId chat.ChatId, memberId chat.MemberId) (*chat.MemberRecord, error) { - s.mu.Lock() - defer s.mu.Unlock() - - item := s.findMemberById(chatId, memberId) - if item == nil { - return nil, chat.ErrMemberNotFound +// GetChatMetadata retrieves the metadata record for a specific chat +func (s *InMemoryStore) GetChatMetadata(_ context.Context, chatId chat.ChatId) (*chat.MetadataRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err } - cloned := item.Clone() - return &cloned, nil -} - -// GetMessageById implements chat.Store.GetMessageById -func (s *store) GetMessageById(_ context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() - item := s.findMessageById(chatId, messageId) - if item == nil { - return nil, chat.ErrMessageNotFound + if md, exists := s.chats[string(chatId[:])]; exists { + cloned := md.Clone() + return &cloned, nil } - cloned := item.Clone() - return &cloned, nil + return nil, chat.ErrChatNotFound } -// GetAllMembersByChatId implements chat.Store.GetAllMembersByChatId -func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) { - items := s.findMembersByChatId(chatId) - if len(items) == 0 { - return nil, chat.ErrMemberNotFound +// GetChatMessageV2 retrieves a specific message from a chat +func (s *InMemoryStore) GetChatMessageV2(_ context.Context, chatId chat.ChatId, messageId chat.MessageId) (*chat.MessageRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err } - return cloneMemberRecords(items), nil -} - -// GetAllMembersByPlatformIds implements chat.store.GetAllMembersByPlatformIds -func (s *store) GetAllMembersByPlatformIds(_ context.Context, idByPlatform map[chat.Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { - s.mu.Lock() - defer s.mu.Unlock() - - items := s.findMembersByPlatformIds(idByPlatform) - items, err := s.getMemberRecordPage(items, cursor, direction, limit) - if err != nil { + if err := messageId.Validate(); err != nil { return nil, err } - if len(items) == 0 { - return nil, chat.ErrMemberNotFound - } - return cloneMemberRecords(items), nil -} + s.mu.RLock() + defer s.mu.RUnlock() -// GetUnreadCount implements chat.store.GetUnreadCount -func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) { - s.mu.Lock() - defer s.mu.Unlock() + if messages, exists := s.messages[string(chatId[:])]; exists { + for _, message := range messages { + if bytes.Equal(message.MessageId[:], messageId[:]) { + clone := message.Clone() + return &clone, nil + } + } + } - items := s.findMessagesByChatId(chatId) - items = s.filterMessagesAfter(items, readPointer) - items = s.filterMessagesNotSentBy(items, memberId) - items = s.filterNotifiedMessages(items) - return uint32(len(items)), nil + return nil, chat.ErrMessageNotFound } -// GetAllMessagesByChatId implements chat.Store.GetAllMessagesByChatId -func (s *store) GetAllMessagesByChatId(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { - s.mu.Lock() - defer s.mu.Unlock() +// GetAllChatsForUserV2 retrieves all chat IDs that a given user belongs to +func (s *InMemoryStore) GetAllChatsForUserV2(_ context.Context, user chat.MemberId, opts ...query.Option) ([]chat.ChatId, error) { + if err := user.Validate(); err != nil { + return nil, err + } - items := s.findMessagesByChatId(chatId) - items, err := s.getMessageRecordPage(items, cursor, direction, limit) + qo := &query.QueryOptions{ + Supported: query.CanQueryByCursor | query.CanLimitResults | query.CanSortBy, + } + err := qo.Apply(opts...) if err != nil { return nil, err } - if len(items) == 0 { - return nil, chat.ErrMessageNotFound - } - return cloneMessageRecords(items), nil -} + s.mu.RLock() + defer s.mu.RUnlock() -// PutChat creates a new chat -func (s *store) PutChat(_ context.Context, record *chat.ChatRecord) error { - if err := record.Validate(); err != nil { - return err + var chatIds []chat.ChatId + for chatIdStr, members := range s.members { + if _, exists := members[user.String()]; exists { + chatId, _ := chat.GetChatIdFromBytes([]byte(chatIdStr)) + chatIds = append(chatIds, chatId) + } } - s.mu.Lock() - defer s.mu.Unlock() - - s.lastChatId++ + // Sort the chatIds + sort.Slice(chatIds, func(i, j int) bool { + if qo.SortBy == query.Descending { + return bytes.Compare(chatIds[i][:], chatIds[j][:]) > 0 + } + return bytes.Compare(chatIds[i][:], chatIds[j][:]) < 0 + }) - if item := s.findChat(record); item != nil { - return chat.ErrChatExists + // Apply cursor if provided + if qo.Cursor != nil { + cursorChatId, err := chat.GetChatIdFromBytes(qo.Cursor) + if err != nil { + return nil, err + } + var filteredChatIds []chat.ChatId + for _, chatId := range chatIds { + if qo.SortBy == query.Descending { + if bytes.Compare(chatId[:], cursorChatId[:]) < 0 { + filteredChatIds = append(filteredChatIds, chatId) + } + } else { + if bytes.Compare(chatId[:], cursorChatId[:]) > 0 { + filteredChatIds = append(filteredChatIds, chatId) + } + } + } + chatIds = filteredChatIds } - record.Id = s.lastChatId - if record.CreatedAt.IsZero() { - record.CreatedAt = time.Now() + // Apply limit if provided + if qo.Limit > 0 && uint64(len(chatIds)) > qo.Limit { + chatIds = chatIds[:qo.Limit] } - cloned := record.Clone() - s.chatRecords = append(s.chatRecords, &cloned) - - return nil + return chatIds, nil } -// PutMember creates a new chat member -func (s *store) PutMember(_ context.Context, record *chat.MemberRecord) error { - if err := record.Validate(); err != nil { - return err +// GetAllChatMessagesV2 retrieves all messages for a specific chat +func (s *InMemoryStore) GetAllChatMessagesV2(_ context.Context, chatId chat.ChatId, opts ...query.Option) ([]*chat.MessageRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err } - s.mu.Lock() - defer s.mu.Unlock() - - s.lastMemberId++ - - if item := s.findMember(record); item != nil { - return chat.ErrMemberExists + qo := &query.QueryOptions{ + Supported: query.CanLimitResults | query.CanSortBy | query.CanQueryByCursor, } - - record.Id = s.lastMemberId - if record.JoinedAt.IsZero() { - record.JoinedAt = time.Now() + if err := qo.Apply(opts...); err != nil { + return nil, err } - cloned := record.Clone() - s.memberRecords = append(s.memberRecords, &cloned) - - return nil -} + s.mu.RLock() + defer s.mu.RUnlock() -// PutMessage implements chat.Store.PutMessage -func (s *store) PutMessage(_ context.Context, record *chat.MessageRecord) error { - if err := record.Validate(); err != nil { - return err + messages, exists := s.messages[string(chatId[:])] + if !exists { + return nil, nil } - s.mu.Lock() - defer s.mu.Unlock() + var result []*chat.MessageRecord + for _, msg := range messages { + cloned := msg.Clone() + result = append(result, &cloned) + } - s.lastMessageId++ + // Sort the messages + sort.Slice(result, func(i, j int) bool { + if qo.SortBy == query.Descending { + return bytes.Compare(result[i].MessageId[:], result[j].MessageId[:]) > 0 + } + return bytes.Compare(result[i].MessageId[:], result[j].MessageId[:]) < 0 + }) - if item := s.findMessage(record); item != nil { - return chat.ErrMessageExsits + // Apply cursor if provided + if len(qo.Cursor) > 0 { + cursorMessageId, err := chat.GetMessageIdFromBytes(qo.Cursor) + if err != nil { + return nil, err + } + var filteredMessages []*chat.MessageRecord + for _, msg := range result { + if qo.SortBy == query.Descending { + if bytes.Compare(msg.MessageId[:], cursorMessageId[:]) < 0 { + filteredMessages = append(filteredMessages, msg) + } + } else { + if bytes.Compare(msg.MessageId[:], cursorMessageId[:]) > 0 { + filteredMessages = append(filteredMessages, msg) + } + } + } + result = filteredMessages } - record.Id = s.lastMessageId - - cloned := record.Clone() - s.messageRecords = append(s.messageRecords, &cloned) + // Apply limit if provided + if qo.Limit > 0 && uint64(len(result)) > qo.Limit { + result = result[:qo.Limit] + } - return nil + return result, nil } -// AdvancePointer implements chat.Store.AdvancePointer -func (s *store) AdvancePointer(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) (bool, error) { - switch pointerType { - case chat.PointerTypeDelivered, chat.PointerTypeRead: - default: - return false, chat.ErrInvalidPointerType +// GetChatMembersV2 retrieves all members of a specific chat +func (s *InMemoryStore) GetChatMembersV2(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) { + if err := chatId.Validate(); err != nil { + return nil, err } - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() - item := s.findMemberById(chatId, memberId) - if item == nil { - return false, chat.ErrMemberNotFound + members, exists := s.members[string(chatId[:])] + if !exists { + return nil, chat.ErrChatNotFound } - var currentPointer *chat.MessageId - switch pointerType { - case chat.PointerTypeDelivered: - currentPointer = item.DeliveryPointer - case chat.PointerTypeRead: - currentPointer = item.ReadPointer + var result []*chat.MemberRecord + for _, member := range members { + cloned := member.Clone() + result = append(result, &cloned) } - if currentPointer == nil || currentPointer.Before(pointer) { - switch pointerType { - case chat.PointerTypeDelivered: - cloned := pointer.Clone() - item.DeliveryPointer = &cloned - case chat.PointerTypeRead: - cloned := pointer.Clone() - item.ReadPointer = &cloned - } + slices.SortFunc(result, func(a, b *chat.MemberRecord) int { + return strings.Compare(a.MemberId, b.MemberId) + }) - return true, nil - } - return false, nil + return result, nil } -// UpgradeIdentity implements chat.Store.UpgradeIdentity -func (s *store) UpgradeIdentity(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, platform chat.Platform, platformId string) error { - switch platform { - case chat.PlatformTwitter: - default: - return errors.Errorf("platform not supported for identity upgrades: %s", platform.String()) +// IsChatMember checks if a given member is part of a specific chat +func (s *InMemoryStore) IsChatMember(_ context.Context, chatId chat.ChatId, memberId chat.MemberId) (bool, error) { + if err := chatId.Validate(); err != nil { + return false, err + } + if err := memberId.Validate(); err != nil { + return false, err } - s.mu.Lock() - defer s.mu.Unlock() + s.mu.RLock() + defer s.mu.RUnlock() - item := s.findMemberById(chatId, memberId) - if item == nil { - return chat.ErrMemberNotFound - } - if item.Platform != chat.PlatformCode { - return chat.ErrMemberIdentityAlreadyUpgraded + if members, exists := s.members[string(chatId[:])]; exists { + _, exists = members[memberId.String()] + return exists, nil } - item.Platform = platform - item.PlatformId = platformId - - return nil + return false, nil } -// SetMuteState implements chat.Store.SetMuteState -func (s *store) SetMuteState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { +// PutChatV2 stores or updates the metadata for a specific chat +func (s *InMemoryStore) PutChatV2(_ context.Context, record *chat.MetadataRecord) error { + if err := record.Validate(); err != nil { + return err + } + s.mu.Lock() defer s.mu.Unlock() - item := s.findMemberById(chatId, memberId) - if item == nil { - return chat.ErrMemberNotFound + if _, exists := s.chats[string(record.ChatId[:])]; exists { + return chat.ErrChatExists } - item.IsMuted = isMuted + cloned := record.Clone() + s.chats[string(record.ChatId[:])] = &cloned return nil } -// SetSubscriptionState implements chat.Store.SetSubscriptionState -func (s *store) SetSubscriptionState(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, isSubscribed bool) error { +// PutChatMemberV2 stores or updates a member record for a specific chat +func (s *InMemoryStore) PutChatMemberV2(_ context.Context, record *chat.MemberRecord) error { + if err := record.Validate(); err != nil { + return err + } + s.mu.Lock() defer s.mu.Unlock() - item := s.findMemberById(chatId, memberId) - if item == nil { - return chat.ErrMemberNotFound + members, exists := s.members[string(record.ChatId[:])] + if !exists { + members = make(map[string]*chat.MemberRecord) + s.members[string(record.ChatId[:])] = members } - item.IsUnsubscribed = !isSubscribed - - return nil -} + if _, exists = members[record.MemberId]; exists { + return chat.ErrMemberExists + } -func (s *store) findChat(data *chat.ChatRecord) *chat.ChatRecord { - for _, item := range s.chatRecords { - if data.Id == item.Id { - return item - } + cloned := record.Clone() + members[record.MemberId] = &cloned - if bytes.Equal(data.ChatId[:], item.ChatId[:]) { - return item - } - } return nil } -func (s *store) findChatById(chatId chat.ChatId) *chat.ChatRecord { - for _, item := range s.chatRecords { - if bytes.Equal(chatId[:], item.ChatId[:]) { - return item - } +// PutChatMessageV2 stores or updates a message record in a specific chat +func (s *InMemoryStore) PutChatMessageV2(_ context.Context, record *chat.MessageRecord) error { + if err := record.Validate(); err != nil { + return err } - return nil -} -func (s *store) findMember(data *chat.MemberRecord) *chat.MemberRecord { - for _, item := range s.memberRecords { - if data.Id == item.Id { - return item - } + s.mu.Lock() + defer s.mu.Unlock() - if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MemberId[:], item.MemberId[:]) { - return item - } + messages := s.messages[string(record.ChatId[:])] + if messages == nil { + messages = make([]*chat.MessageRecord, 0) + s.messages[string(record.ChatId[:])] = messages } - return nil -} -func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat.MemberRecord { - for _, item := range s.memberRecords { - if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(memberId[:], item.MemberId[:]) { - return item - } + i, found := sort.Find(len(messages), func(i int) int { + return bytes.Compare(record.MessageId[:], messages[i].MessageId[:]) + }) + if found { + return chat.ErrMessageExists } - return nil -} - -func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord { - var res []*chat.MemberRecord - for _, item := range s.memberRecords { - if bytes.Equal(chatId[:], item.ChatId[:]) { - res = append(res, item) - } - } - return res -} -func (s *store) findMembersByPlatformIds(idByPlatform map[chat.Platform]string) []*chat.MemberRecord { - var res []*chat.MemberRecord - for _, item := range s.memberRecords { - platformId, ok := idByPlatform[item.Platform] - if !ok { - continue - } + cloned := record.Clone() + messages = slices.Insert(messages, i, &cloned) + s.messages[string(record.ChatId[:])] = messages - if platformId == item.PlatformId { - res = append(res, item) - } - } - return res + return nil } -func (s *store) getMemberRecordPage(items []*chat.MemberRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MemberRecord, error) { - if len(items) == 0 { - return nil, nil +// SetChatMuteStateV2 sets the mute state for a specific chat member +func (s *InMemoryStore) SetChatMuteStateV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, isMuted bool) error { + if err := chatId.Validate(); err != nil { + return err } - - var memberIdCursor *uint64 - if len(cursor) > 0 { - cursorValue := query.FromCursor(cursor) - memberIdCursor = &cursorValue + if err := memberId.Validate(); err != nil { + return err } - var res []*chat.MemberRecord - if memberIdCursor == nil { - res = items - } else { - for _, item := range items { - if item.Id > int64(*memberIdCursor) && direction == query.Ascending { - res = append(res, item) - } + s.mu.Lock() + defer s.mu.Unlock() - if item.Id < int64(*memberIdCursor) && direction == query.Descending { - res = append(res, item) - } + if members, exists := s.members[string(chatId[:])]; exists { + if member, exists := members[memberId.String()]; exists { + member.IsMuted = isMuted + return nil } } + return chat.ErrMemberNotFound +} - if direction == query.Ascending { - sort.Sort(chat.MembersById(res)) - } else { - sort.Sort(sort.Reverse(chat.MembersById(res))) +// AdvanceChatPointerV2 advances a pointer for a chat member +func (s *InMemoryStore) AdvanceChatPointerV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, pointerType chat.PointerType, pointer chat.MessageId) (bool, error) { + if err := chatId.Validate(); err != nil { + return false, err } - - if len(res) >= int(limit) { - return res[:limit], nil + if err := memberId.Validate(); err != nil { + return false, err + } + if err := pointer.Validate(); err != nil { + return false, err } - return res, nil -} - -func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord { - for _, item := range s.messageRecords { - if data.Id == item.Id { - return item - } + s.mu.Lock() + defer s.mu.Unlock() - if bytes.Equal(data.ChatId[:], item.ChatId[:]) && bytes.Equal(data.MessageId[:], item.MessageId[:]) { - return item - } + members, exists := s.members[string(chatId[:])] + if !exists { + return false, chat.ErrMemberNotFound } - return nil -} -func (s *store) findMessageById(chatId chat.ChatId, messageId chat.MessageId) *chat.MessageRecord { - for _, item := range s.messageRecords { - if bytes.Equal(chatId[:], item.ChatId[:]) && bytes.Equal(messageId[:], item.MessageId[:]) { - return item - } + member, exists := members[memberId.String()] + if !exists { + return false, chat.ErrMemberNotFound } - return nil -} -func (s *store) findMessagesByChatId(chatId chat.ChatId) []*chat.MessageRecord { - var res []*chat.MessageRecord - for _, item := range s.messageRecords { - if bytes.Equal(chatId[:], item.ChatId[:]) { - res = append(res, item) + switch pointerType { + case chat.PointerTypeSent: + case chat.PointerTypeDelivered: + if member.DeliveryPointer == nil || bytes.Compare(pointer[:], member.DeliveryPointer[:]) > 0 { + newPtr := pointer.Clone() + member.DeliveryPointer = &newPtr + return true, nil } - } - return res -} - -func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.MessageId) []*chat.MessageRecord { - var res []*chat.MessageRecord - for _, item := range items { - if item.MessageId.After(pointer) { - res = append(res, item) + case chat.PointerTypeRead: + if member.ReadPointer == nil || bytes.Compare(pointer[:], member.ReadPointer[:]) > 0 { + newPtr := pointer.Clone() + member.ReadPointer = &newPtr + return true, nil } + default: + return false, chat.ErrInvalidPointerType } - return res -} -func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord { - var res []*chat.MessageRecord - for _, item := range items { - if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) { - res = append(res, item) - } - } - return res + return false, nil } -func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord { - var res []*chat.MessageRecord - for _, item := range items { - if !item.IsSilent { - res = append(res, item) - } +// GetChatUnreadCountV2 calculates and returns the unread message count +func (s *InMemoryStore) GetChatUnreadCountV2(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer *chat.MessageId) (uint32, error) { + if err := chatId.Validate(); err != nil { + return 0, err } - return res -} - -func (s *store) getMessageRecordPage(items []*chat.MessageRecord, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) { - if len(items) == 0 { - return nil, nil + if err := memberId.Validate(); err != nil { + return 0, err } - - var messageIdCursor *chat.MessageId - if len(cursor) > 0 { - messageId, err := chat.GetMessageIdFromBytes(cursor) - if err != nil { - return nil, err + if readPointer != nil { + if err := readPointer.Validate(); err != nil { + return 0, err } - messageIdCursor = &messageId } - var res []*chat.MessageRecord - if messageIdCursor == nil { - res = items - } else { - for _, item := range items { - if item.MessageId.After(*messageIdCursor) && direction == query.Ascending { - res = append(res, item) - } + s.mu.RLock() + defer s.mu.RUnlock() - if item.MessageId.Before(*messageIdCursor) && direction == query.Descending { - res = append(res, item) + unread := uint32(0) + messages := s.messages[string(chatId[:])] + for _, message := range messages { + if readPointer != nil { + if bytes.Compare(message.MessageId[:], readPointer[:]) <= 0 { + continue } } - } - - if direction == query.Ascending { - sort.Sort(chat.MessagesByMessageId(res)) - } else { - sort.Sort(sort.Reverse(chat.MessagesByMessageId(res))) - } - if len(res) >= int(limit) { - return res[:limit], nil - } - - return res, nil -} - -func (s *store) reset() { - s.mu.Lock() - defer s.mu.Unlock() - - s.chatRecords = nil - s.memberRecords = nil - s.messageRecords = nil - - s.lastChatId = 0 - s.lastMemberId = 0 - s.lastMessageId = 0 -} + if message.Sender.String() == memberId.String() { + continue + } -func cloneMemberRecords(items []*chat.MemberRecord) []*chat.MemberRecord { - res := make([]*chat.MemberRecord, len(items)) - for i, item := range items { - cloned := item.Clone() - res[i] = &cloned + unread++ } - return res -} -func cloneMessageRecords(items []*chat.MessageRecord) []*chat.MessageRecord { - res := make([]*chat.MessageRecord, len(items)) - for i, item := range items { - cloned := item.Clone() - res[i] = &cloned - } - return res + return unread, nil } diff --git a/pkg/code/data/chat/v2/memory/store_test.go b/pkg/code/data/chat/v2/memory/store_test.go index cd61dfa4..48f10ff1 100644 --- a/pkg/code/data/chat/v2/memory/store_test.go +++ b/pkg/code/data/chat/v2/memory/store_test.go @@ -1,15 +1,535 @@ package memory import ( + "bytes" + "context" + "fmt" + "math/rand/v2" + "slices" "testing" + "time" - "github.com/code-payments/code-server/pkg/code/data/chat/v2/tests" + "github.com/stretchr/testify/require" + + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/pointer" ) -func TestChatMemoryStore(t *testing.T) { - testStore := New() - teardown := func() { - testStore.(*store).reset() +func TestInMemoryStore_GetChatMetadata(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + metadata := &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + ChatTitle: pointer.String("hello"), + } + + result, err := store.GetChatMetadata(context.Background(), chatId) + require.ErrorIs(t, err, chat.ErrChatNotFound) + require.Nil(t, result) + + require.NoError(t, store.PutChatV2(context.Background(), metadata)) + require.ErrorIs(t, store.PutChatV2(context.Background(), metadata), chat.ErrChatExists) + + result, err = store.GetChatMetadata(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, metadata.Clone(), result.Clone()) +} + +func TestInMemoryStore_GetAllChatsForUserV2(t *testing.T) { + store := New() + + memberId := chat.MemberId("user123") + + chatIds, err := store.GetAllChatsForUserV2(context.Background(), memberId) + require.NoError(t, err) + require.Empty(t, chatIds) + + var expectedChatIds []chat.ChatId + for i := 0; i < 10; i++ { + chatId := chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + expectedChatIds = append(expectedChatIds, chatId) + + require.NoError(t, store.PutChatV2(context.Background(), &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + })) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + require.ErrorIs(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + }), chat.ErrMemberExists) + } + + chatIds, err = store.GetAllChatsForUserV2(context.Background(), memberId) + require.NoError(t, err) + require.Equal(t, expectedChatIds, chatIds) +} + +func TestInMemoryStore_GetAllChatsForUserV2_Pagination(t *testing.T) { + store := New() + + memberId := chat.MemberId("user123") + + // Create 10 chats + var chatIds []chat.ChatId + for i := 0; i < 10; i++ { + chatId := chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + chatIds = append(chatIds, chatId) + + require.NoError(t, store.PutChatV2(context.Background(), &chat.MetadataRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + })) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + } + + reversedChatIds := slices.Clone(chatIds) + slices.Reverse(reversedChatIds) + + t.Run("Ascending Order", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Ascending)) + require.NoError(t, err) + require.Equal(t, chatIds, result) + }) + + t.Run("Descending Order", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending)) + require.NoError(t, err) + require.Equal(t, reversedChatIds, result) + }) + + t.Run("With Cursor", func(t *testing.T) { + cursor := chatIds[3][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Ascending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, chatIds[4:], result) + }) + + t.Run("With Cursor (Descending)", func(t *testing.T) { + cursor := reversedChatIds[6][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, reversedChatIds[7:], result) + }) + + t.Run("With Limit", func(t *testing.T) { + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithLimit(5)) + require.NoError(t, err) + require.Equal(t, chatIds[:5], result) + }) + + t.Run("With Limit (Descending)", func(t *testing.T) { + cursor := reversedChatIds[4][:] + result, err := store.GetAllChatsForUserV2(context.Background(), memberId, query.WithDirection(query.Descending), query.WithCursor(cursor), query.WithLimit(3)) + require.NoError(t, err) + require.Equal(t, reversedChatIds[5:8], result) + }) +} + +func TestInMemoryStore_GetChatMessageV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + messageId := chat.GenerateMessageId() + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + Payload: []byte("payload"), + } + + err := store.PutChatMessageV2(context.Background(), message) + require.NoError(t, err) + + result, err := store.GetChatMessageV2(context.Background(), chatId, messageId) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, bytes.Equal(result.MessageId[:], messageId[:])) +} + +// TODO: Need proper pagination tests +func TestInMemoryStore_GetAllChatMessagesV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMessages []*chat.MessageRecord + for i := 0; i < 10; i++ { + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Payload: []byte(fmt.Sprintf("payload-%d", i)), + } + expectedMessages = append(expectedMessages, message) + + // TODO: We might need a way to address this longer term. + time.Sleep(time.Millisecond) + + require.NoError(t, store.PutChatMessageV2(context.Background(), message)) + require.ErrorIs(t, store.PutChatMessageV2(context.Background(), message), chat.ErrMessageExists) + } + + isSorted := slices.IsSortedFunc(expectedMessages, func(a, b *chat.MessageRecord) int { + return bytes.Compare(a.MessageId[:], b.MessageId[:]) + }) + require.True(t, isSorted) + + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, len(expectedMessages), len(messages)) + + for i := 0; i < len(messages); i++ { + require.Equal(t, expectedMessages[i].ChatId, messages[i].ChatId) + require.Equal(t, expectedMessages[i].MessageId, messages[i].MessageId) + require.Equal(t, expectedMessages[i].Sender, messages[i].Sender) + require.Equal(t, expectedMessages[i].Payload, messages[i].Payload) + require.Equal(t, expectedMessages[i].IsSilent, messages[i].IsSilent) + } +} + +func TestInMemoryStore_GetAllChatMessagesV2_Pagination(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMessages []*chat.MessageRecord + for i := 0; i < 10; i++ { + message := &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Payload: []byte(fmt.Sprintf("payload-%d", i)), + } + expectedMessages = append(expectedMessages, message) + time.Sleep(time.Millisecond) + require.NoError(t, store.PutChatMessageV2(context.Background(), message)) + } + + reversedMessages := slices.Clone(expectedMessages) + slices.Reverse(reversedMessages) + + t.Run("Ascending order", func(t *testing.T) { + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending)) + require.NoError(t, err) + require.Equal(t, expectedMessages, messages) + }) + + t.Run("Descending order", func(t *testing.T) { + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Descending)) + require.NoError(t, err) + require.Equal(t, reversedMessages, messages) + }) + + t.Run("With limit", func(t *testing.T) { + limit := uint64(5) + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending), query.WithLimit(limit)) + require.NoError(t, err) + require.Equal(t, expectedMessages[:limit], messages) + }) + + t.Run("With cursor", func(t *testing.T) { + cursor := expectedMessages[3].MessageId[:] + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Ascending), query.WithCursor(cursor)) + require.NoError(t, err) + require.Equal(t, expectedMessages[4:], messages) + }) + + t.Run("With cursor and limit", func(t *testing.T) { + cursor := reversedMessages[3].MessageId[:] + limit := uint64(3) + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId, query.WithDirection(query.Descending), query.WithCursor(cursor), query.WithLimit(limit)) + require.NoError(t, err) + require.Equal(t, reversedMessages[4:7], messages) + }) +} + +// TODO: Need proper pagination tests +func TestInMemoryStore_GetChatMembersV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + + var expectedMembers []*chat.MemberRecord + for i := 0; i < 10; i++ { + member := &chat.MemberRecord{ + ChatId: chatId, + MemberId: fmt.Sprintf("user%d", i), + Owner: fmt.Sprintf("owner%d", i), + Platform: chat.PlatformTwitter, + PlatformId: fmt.Sprintf("twitter%d", i), + IsMuted: true, + JoinedAt: time.Now(), + } + + dPtr := chat.GenerateMessageId() + time.Sleep(time.Millisecond) + rPtr := chat.GenerateMessageId() + + member.DeliveryPointer = &dPtr + member.ReadPointer = &rPtr + + expectedMembers = append(expectedMembers, member) + + require.NoError(t, store.PutChatMemberV2(context.Background(), member)) + require.ErrorIs(t, store.PutChatMemberV2(context.Background(), member), chat.ErrMemberExists) + } + + members, err := store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.Equal(t, expectedMembers, members) +} + +func TestInMemoryStore_IsChatMember(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + isMember, err := store.IsChatMember(context.Background(), chatId, memberId) + require.NoError(t, err) + require.False(t, isMember) + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + })) + + isMember, err = store.IsChatMember(context.Background(), chatId, memberId) + require.NoError(t, err) + require.True(t, isMember) +} + +func TestInMemoryStore_PutChatV2(t *testing.T) { + store := New() + + for i, expected := range []*chat.MetadataRecord{ + { + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + }, + { + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + ChatTitle: pointer.String("hello"), + }, + } { + expected.ChatId = chat.ChatId(bytes.Repeat([]byte{byte(i)}, 32)) + + require.NoError(t, store.PutChatV2(context.Background(), expected)) + + other := expected.Clone() + other.ChatTitle = pointer.String("mutated") + require.ErrorIs(t, store.PutChatV2(context.Background(), &other), chat.ErrChatExists) + + actual, err := store.GetChatMetadata(context.Background(), expected.ChatId) + require.NoError(t, err) + require.Equal(t, expected, actual) + } + + for _, invalid := range []*chat.MetadataRecord{ + {}, + { + ChatId: chat.ChatId{1, 2, 3}, + }, + { + ChatId: chat.ChatId{1, 2, 3}, + CreatedAt: time.Now(), + }, + { + ChatId: chat.ChatId{1, 2, 3}, + ChatType: chat.ChatTypeTwoWay, + }, + } { + require.Error(t, store.PutChatV2(context.Background(), invalid)) } - tests.RunTests(t, testStore, teardown) +} + +func TestInMemoryStore_SetChatMuteStateV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "user", + MemberId: memberId.String(), + JoinedAt: time.Now(), + })) + + members, err := store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.False(t, members[0].IsMuted) + + require.NoError(t, store.SetChatMuteStateV2(context.Background(), chatId, memberId, true)) + + members, err = store.GetChatMembersV2(context.Background(), chatId) + require.NoError(t, err) + require.True(t, members[0].IsMuted) +} + +func TestInMemoryStore_GetChatUnreadCountV2(t *testing.T) { + store := New() + + // Create multiple chats + chats := []chat.ChatId{ + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9}, + } + counts := []int{0, 0, 0} + + ourMemberId := chat.MemberId("our_user") + otherMemberId := chat.MemberId("other_user") + + for chatIdx, chatId := range chats { + // Add members to the chat + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "our_user", + MemberId: ourMemberId.String(), + JoinedAt: time.Now(), + })) + require.NoError(t, store.PutChatMemberV2(context.Background(), &chat.MemberRecord{ + ChatId: chatId, + Platform: chat.PlatformTwitter, + PlatformId: "other_user", + MemberId: otherMemberId.String(), + JoinedAt: time.Now(), + })) + + // Generate N messages for each chat + N := 10 + for i := 0; i < N; i++ { + sender := ourMemberId + if rand.IntN(100) < 50 { // Approximately 50% chance for a message to be from the other user + sender = otherMemberId + counts[chatIdx]++ + } + + require.NoError(t, store.PutChatMessageV2(context.Background(), &chat.MessageRecord{ + ChatId: chatId, + MessageId: chat.GenerateMessageId(), + Sender: &sender, + Payload: []byte(fmt.Sprintf("Message %d for chat %v", i, chatId)), + })) + + time.Sleep(time.Millisecond) + } + } + + // Verify that each chat has a distinct unread count + for chatIdx, chatId := range chats { + ptr := chat.GenerateMessageIdAtTime(time.Now().Add(-time.Hour)) + count, err := store.GetChatUnreadCountV2(context.Background(), chatId, ourMemberId, &ptr) + require.NoError(t, err) + require.EqualValues(t, counts[chatIdx], count) + + if count == 0 { + continue + } + + messages, err := store.GetAllChatMessagesV2(context.Background(), chatId) + require.NoError(t, err) + + var offset *chat.MessageId + for _, message := range messages { + if message.Sender != nil && !bytes.Equal(*message.Sender, ourMemberId) { + offset = &message.MessageId + break + } + } + require.NotNil(t, offset) + + newCount, err := store.GetChatUnreadCountV2(context.Background(), chatId, ourMemberId, offset) + require.NoError(t, err) + require.Equal(t, count-1, newCount) + } +} + +func TestInMemoryStore_AdvanceChatPointerV2(t *testing.T) { + store := New() + + chatId := chat.ChatId{1, 2, 3} + memberId := chat.MemberId("user123") + + // Create a chat and add a member + metadata := &chat.MetadataRecord{ + Id: 0, + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: time.Now(), + } + require.NoError(t, store.PutChatV2(context.Background(), metadata)) + + member := &chat.MemberRecord{ + ChatId: chatId, + MemberId: memberId.String(), + Platform: chat.PlatformTwitter, + PlatformId: "user", + JoinedAt: time.Now(), + + DeliveryPointer: nil, + ReadPointer: nil, + } + require.NoError(t, store.PutChatMemberV2(context.Background(), member)) + + // Test advancing delivery pointer + message1 := chat.GenerateMessageId() + advanced, err := store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeDelivered, message1) + require.NoError(t, err) + require.True(t, advanced) + + // Test advancing read pointer + message2 := chat.GenerateMessageId() + advanced, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeRead, message2) + require.NoError(t, err) + require.True(t, advanced) + + // Test advancing to an earlier message (should not advance) + advanced, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerTypeDelivered, message1) + require.NoError(t, err) + require.False(t, advanced) + + // Test with invalid pointer type + _, err = store.AdvanceChatPointerV2(context.Background(), chatId, memberId, chat.PointerType(8), message2) + require.ErrorIs(t, err, chat.ErrInvalidPointerType) + + // Test with non-existent chat + nonExistentChatId := chat.ChatId{4, 5, 6} + _, err = store.AdvanceChatPointerV2(context.Background(), nonExistentChatId, memberId, chat.PointerTypeDelivered, message2) + require.ErrorIs(t, err, chat.ErrMemberNotFound) + + // Test with non-existent member + nonExistentMemberId := chat.MemberId("nonexistent") + _, err = store.AdvanceChatPointerV2(context.Background(), chatId, nonExistentMemberId, chat.PointerTypeDelivered, message2) + require.ErrorIs(t, err, chat.ErrMemberNotFound) } diff --git a/pkg/code/data/chat/v2/model.go b/pkg/code/data/chat/v2/model.go index bae36478..dd80ff2a 100644 --- a/pkg/code/data/chat/v2/model.go +++ b/pkg/code/data/chat/v2/model.go @@ -1,9 +1,9 @@ package chat_v2 import ( + "fmt" "time" - "github.com/mr-tron/base58" "github.com/pkg/errors" "github.com/code-payments/code-server/pkg/pointer" @@ -15,122 +15,13 @@ type ChatType uint8 const ( ChatTypeUnknown ChatType = iota - ChatTypeNotification ChatTypeTwoWay // ChatTypeGroup ) -type ReferenceType uint8 - -const ( - ReferenceTypeUnknown ReferenceType = iota - ReferenceTypeIntent - ReferenceTypeSignature -) - -type PointerType uint8 - -const ( - PointerTypeUnknown PointerType = iota - PointerTypeSent - PointerTypeDelivered - PointerTypeRead -) - -type Platform uint8 - -const ( - PlatformUnknown Platform = iota - PlatformCode - PlatformTwitter -) - -type ChatRecord struct { - Id int64 - ChatId ChatId - - ChatType ChatType - - // Presence determined by ChatType: - // * Notification: Present, and may be a localization key - // * Two Way: Not present and generated dynamically based on chat members - // * Group: Present, and will not be a localization key - ChatTitle *string - - IsVerified bool - - CreatedAt time.Time -} - -type MemberRecord struct { - Id int64 - ChatId ChatId - MemberId MemberId - - Platform Platform - PlatformId string - - // If Platform != PlatformCode, this store the owner - // of the account (at time of creation). This allows - // us to send push notifications for non-code users. - OwnerAccount string - - DeliveryPointer *MessageId - ReadPointer *MessageId - - IsMuted bool - IsUnsubscribed bool - - JoinedAt time.Time -} - -func (m *MemberRecord) GetOwner() string { - if m.Platform == PlatformCode { - return m.PlatformId - } - - return m.OwnerAccount -} - -type MessageRecord struct { - Id int64 - ChatId ChatId - MessageId MessageId - - // Not present for notification-style chats - Sender *MemberId - - Data []byte - - ReferenceType *ReferenceType - Reference *string - - IsSilent bool - - // Note: No timestamp field, since it's encoded in MessageId -} - -type MembersById []*MemberRecord - -func (a MembersById) Len() int { return len(a) } -func (a MembersById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a MembersById) Less(i, j int) bool { - return a[i].Id < a[j].Id -} - -type MessagesByMessageId []*MessageRecord - -func (a MessagesByMessageId) Len() int { return len(a) } -func (a MessagesByMessageId) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a MessagesByMessageId) Less(i, j int) bool { - return a[i].MessageId.Before(a[j].MessageId) -} - // GetChatTypeFromProto gets a chat type from the protobuf variant func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { switch proto { - case chatpb.ChatType_NOTIFICATION: - return ChatTypeNotification case chatpb.ChatType_TWO_WAY: return ChatTypeTwoWay default: @@ -141,8 +32,6 @@ func GetChatTypeFromProto(proto chatpb.ChatType) ChatType { // ToProto returns the proto representation of the chat type func (c ChatType) ToProto() chatpb.ChatType { switch c { - case ChatTypeNotification: - return chatpb.ChatType_NOTIFICATION case ChatTypeTwoWay: return chatpb.ChatType_TWO_WAY default: @@ -153,8 +42,6 @@ func (c ChatType) ToProto() chatpb.ChatType { // String returns the string representation of the chat type func (c ChatType) String() string { switch c { - case ChatTypeNotification: - return "notification" case ChatTypeTwoWay: return "two-way" default: @@ -162,6 +49,15 @@ func (c ChatType) String() string { } } +type PointerType uint8 + +const ( + PointerTypeUnknown PointerType = iota + PointerTypeSent + PointerTypeDelivered + PointerTypeRead +) + // GetPointerTypeFromProto gets a chat ID from the protobuf variant func GetPointerTypeFromProto(proto chatpb.PointerType) PointerType { switch proto { @@ -204,7 +100,14 @@ func (p PointerType) String() string { } } -// ToProto returns the proto representation of the platform +type Platform uint8 + +const ( + PlatformUnknown Platform = iota + PlatformTwitter +) + +// GetPlatformFromProto returns the proto representation of the platform func GetPlatformFromProto(proto chatpb.Platform) Platform { switch proto { case chatpb.Platform_TWITTER: @@ -227,8 +130,6 @@ func (p Platform) ToProto() chatpb.Platform { // String returns the string representation of the platform func (p Platform) String() string { switch p { - case PlatformCode: - return "code" case PlatformTwitter: return "twitter" default: @@ -236,21 +137,23 @@ func (p Platform) String() string { } } +type MetadataRecord struct { + Id int64 + ChatId ChatId + ChatType ChatType + CreatedAt time.Time + + ChatTitle *string +} + // Validate validates a chat Record -func (r *ChatRecord) Validate() error { +func (r *MetadataRecord) Validate() error { if err := r.ChatId.Validate(); err != nil { return errors.Wrap(err, "invalid chat id") } switch r.ChatType { - case ChatTypeNotification: - if r.ChatTitle == nil || len(*r.ChatTitle) == 0 { - return errors.New("chat title is required for notification chats") - } case ChatTypeTwoWay: - if r.ChatTitle != nil { - return errors.New("chat title cannot be set for two way chats") - } default: return errors.Errorf("invalid chat type: %d", r.ChatType) } @@ -263,62 +166,71 @@ func (r *ChatRecord) Validate() error { } // Clone clones a chat record -func (r *ChatRecord) Clone() ChatRecord { - return ChatRecord{ - Id: r.Id, - ChatId: r.ChatId, - - ChatType: r.ChatType, +func (r *MetadataRecord) Clone() MetadataRecord { + return MetadataRecord{ + Id: r.Id, + ChatId: r.ChatId, + ChatType: r.ChatType, + CreatedAt: r.CreatedAt, ChatTitle: pointer.StringCopy(r.ChatTitle), - - IsVerified: r.IsVerified, - - CreatedAt: r.CreatedAt, } } // CopyTo copies a chat record to the provided destination -func (r *ChatRecord) CopyTo(dst *ChatRecord) { +func (r *MetadataRecord) CopyTo(dst *MetadataRecord) { dst.Id = r.Id dst.ChatId = r.ChatId - dst.ChatType = r.ChatType + dst.CreatedAt = r.CreatedAt dst.ChatTitle = pointer.StringCopy(r.ChatTitle) +} - dst.IsVerified = r.IsVerified +type MemberRecord struct { + Id int64 + ChatId ChatId - dst.CreatedAt = r.CreatedAt + // MemberId is derived from Owner (using account.ToMessagingAccount) + // + // It is stored to allow indexed lookups when only MemberId is available. + // We must also store Owner so server can lookup proper push tokens. + MemberId string + + // Owner is required to be able to send push notifications. + // + // Currently, it is _optional_, as we don't have a way to reverse lookup. + // However, we _will_ want to make it mandatory. + Owner string + + // Identity. + // + // Currently, assumes single. + Platform Platform + PlatformId string + + DeliveryPointer *MessageId + ReadPointer *MessageId + + IsMuted bool + JoinedAt time.Time } // Validate validates a member Record func (r *MemberRecord) Validate() error { if err := r.ChatId.Validate(); err != nil { - return errors.Wrap(err, "invalid chat id") + return fmt.Errorf("invalid chat id: %w", err) } - if err := r.MemberId.Validate(); err != nil { - return errors.Wrap(err, "invalid member id") + if len(r.MemberId) == 0 { + return fmt.Errorf("missing member id") } if len(r.PlatformId) == 0 { - return errors.New("platform id is required") - } - if r.Platform != PlatformCode && len(r.OwnerAccount) == 0 { - return errors.New("owner account is required for non code platform members") + return fmt.Errorf("missing platform id") } switch r.Platform { - case PlatformCode: - decoded, err := base58.Decode(r.PlatformId) - if err != nil { - return errors.Wrap(err, "invalid base58 plaftorm id") - } - - if len(decoded) != 32 { - return errors.Wrap(err, "platform id is not a 32 byte buffer") - } case PlatformTwitter: if len(r.PlatformId) > 15 { return errors.New("platform id must have at most 15 characters") @@ -364,17 +276,15 @@ func (r *MemberRecord) Clone() MemberRecord { Id: r.Id, ChatId: r.ChatId, MemberId: r.MemberId, + Owner: r.Owner, - Platform: r.Platform, - PlatformId: r.PlatformId, - OwnerAccount: r.OwnerAccount, + Platform: r.Platform, + PlatformId: r.PlatformId, DeliveryPointer: deliveryPointerCopy, ReadPointer: readPointerCopy, - IsMuted: r.IsMuted, - IsUnsubscribed: r.IsUnsubscribed, - + IsMuted: r.IsMuted, JoinedAt: r.JoinedAt, } } @@ -383,11 +293,11 @@ func (r *MemberRecord) Clone() MemberRecord { func (r *MemberRecord) CopyTo(dst *MemberRecord) { dst.Id = r.Id dst.ChatId = r.ChatId + dst.Owner = r.Owner dst.MemberId = r.MemberId dst.Platform = r.Platform dst.PlatformId = r.PlatformId - dst.OwnerAccount = r.OwnerAccount if r.DeliveryPointer != nil { cloned := r.DeliveryPointer.Clone() @@ -399,11 +309,32 @@ func (r *MemberRecord) CopyTo(dst *MemberRecord) { } dst.IsMuted = r.IsMuted - dst.IsUnsubscribed = r.IsUnsubscribed - dst.JoinedAt = r.JoinedAt } +type MembersById []*MemberRecord + +func (a MembersById) Len() int { return len(a) } +func (a MembersById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MembersById) Less(i, j int) bool { + return a[i].Id < a[j].Id +} + +type MessageRecord struct { + Id int64 + ChatId ChatId + MessageId MessageId + + Sender *MemberId + + Payload []byte + + IsSilent bool + + // Note: No timestamp field, since it's encoded in MessageId + // Note: Maybe a timestamp field, because it's maybe better? +} + // Validate validates a message Record func (r *MessageRecord) Validate() error { if err := r.ChatId.Validate(); err != nil { @@ -420,44 +351,19 @@ func (r *MessageRecord) Validate() error { } } - if len(r.Data) == 0 { - return errors.New("message data is required") - } - - if r.Reference == nil && r.ReferenceType != nil { - return errors.New("reference is required when reference type is provided") - } - - if r.Reference != nil && r.ReferenceType == nil { - return errors.New("reference cannot be set when reference type is missing") + if len(r.Payload) == 0 { + return errors.New("message payload is required") } - if r.ReferenceType != nil { - switch *r.ReferenceType { - case ReferenceTypeIntent: - decoded, err := base58.Decode(*r.Reference) - if err != nil { - return errors.Wrap(err, "invalid base58 intent id reference") - } - - if len(decoded) != 32 { - return errors.Wrap(err, "reference is not a 32 byte buffer") - } - case ReferenceTypeSignature: - decoded, err := base58.Decode(*r.Reference) - if err != nil { - return errors.Wrap(err, "invalid base58 signature reference") - } + return nil +} - if len(decoded) != 64 { - return errors.Wrap(err, "reference is not a 64 byte buffer") - } - default: - return errors.Errorf("invalid reference type: %d", *r.ReferenceType) - } - } +type MessagesByMessageId []*MessageRecord - return nil +func (a MessagesByMessageId) Len() int { return len(a) } +func (a MessagesByMessageId) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a MessagesByMessageId) Less(i, j int) bool { + return a[i].MessageId.Before(a[j].MessageId) } // Clone clones a message record @@ -468,13 +374,10 @@ func (r *MessageRecord) Clone() MessageRecord { senderCopy = &cloned } - dataCopy := make([]byte, len(r.Data)) - copy(dataCopy, r.Data) - - var referenceTypeCopy *ReferenceType - if r.ReferenceType != nil { - cloned := *r.ReferenceType - referenceTypeCopy = &cloned + var payloadCopy []byte + if len(r.Payload) > 0 { + payloadCopy = make([]byte, len(r.Payload)) + copy(payloadCopy, r.Payload) } return MessageRecord{ @@ -484,10 +387,7 @@ func (r *MessageRecord) Clone() MessageRecord { Sender: senderCopy, - Data: dataCopy, - - ReferenceType: referenceTypeCopy, - Reference: pointer.StringCopy(r.Reference), + Payload: payloadCopy, IsSilent: r.IsSilent, } @@ -504,15 +404,9 @@ func (r *MessageRecord) CopyTo(dst *MessageRecord) { dst.Sender = &cloned } - dataCopy := make([]byte, len(r.Data)) - copy(dataCopy, r.Data) - dst.Data = dataCopy - - if r.ReferenceType != nil { - cloned := *r.ReferenceType - dst.ReferenceType = &cloned - } - dst.Reference = pointer.StringCopy(r.Reference) + payloadCopy := make([]byte, len(r.Payload)) + copy(payloadCopy, r.Payload) + dst.Payload = payloadCopy dst.IsSilent = r.IsSilent } diff --git a/pkg/code/data/chat/v2/store.go b/pkg/code/data/chat/v2/store.go index 2029f6db..fa4fbb6e 100644 --- a/pkg/code/data/chat/v2/store.go +++ b/pkg/code/data/chat/v2/store.go @@ -3,88 +3,68 @@ package chat_v2 import ( "context" "errors" - - "github.com/code-payments/code-protobuf-api/generated/go/common/v1" - "github.com/code-payments/code-server/pkg/database/query" ) var ( - ErrChatExists = errors.New("chat already exists") - ErrChatNotFound = errors.New("chat not found") - ErrMemberExists = errors.New("chat member already exists") - ErrMemberNotFound = errors.New("chat member not found") - ErrMemberIdentityAlreadyUpgraded = errors.New("chat member identity already upgraded") - ErrMessageExsits = errors.New("chat message already exists") - ErrMessageNotFound = errors.New("chat message not found") - ErrInvalidPointerType = errors.New("invalid pointer type") + ErrChatExists = errors.New("chat already exists") + ErrChatNotFound = errors.New("chat not found") + ErrMemberExists = errors.New("chat member already exists") + ErrMemberNotFound = errors.New("chat member not found") + ErrMessageExists = errors.New("chat message already exists") + ErrMessageNotFound = errors.New("chat message not found") + ErrInvalidPointerType = errors.New("invalid pointer type") ) -// todo: Define interface methods type Store interface { - // GetChatById gets a chat by its chat ID - GetChatById(ctx context.Context, chatId ChatId) (*ChatRecord, error) - - // GetMemberById gets a chat member by the chat and member IDs - GetMemberById(ctx context.Context, chatId ChatId, memberId MemberId) (*MemberRecord, error) - - // GetMessageById gets a chat message by the chat and message IDs - GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) - - // GetAllMembersByChatId gets all members for a given chat + // GetChatMetadata retrieves the metadata record for a specific chat, identified by chatId. // - // todo: Add paging when we introduce group chats - GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) - - // GetAllMembersByPlatformIds gets all members for platform users across all chats - GetAllMembersByPlatformIds(ctx context.Context, idByPlatform map[Platform]string, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MemberRecord, error) + // It returns ErrChatNotFound if the chat doesn't exist. + GetChatMetadata(ctx context.Context, chatId ChatId) (*MetadataRecord, error) - // GetAllMessagesByChatId gets all messages for a given chat - // - // Note: Cursor is a message ID - GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error) + // GetChatMessageV2 retrieves a specific message from a chat, identified by chatId and messageId. + GetChatMessageV2(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error) - // GetUnreadCount gets the unread message count for a chat ID at a read pointer - GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error) + // GetAllChatsForUserV2 retrieves all chat IDs that a given user (where user is the messaging address). + GetAllChatsForUserV2(ctx context.Context, user MemberId, opts ...query.Option) ([]ChatId, error) - // PutChat creates a new chat - PutChat(ctx context.Context, record *ChatRecord) error + // GetAllChatMessagesV2 retrieves all messages for a specific chat, identified by chatId. + GetAllChatMessagesV2(ctx context.Context, chatId ChatId, opts ...query.Option) ([]*MessageRecord, error) - // PutMember creates a new chat member - PutMember(ctx context.Context, record *MemberRecord) error + // GetChatMembersV2 retrieves all members of a specific chat, identified by chatId. + GetChatMembersV2(ctx context.Context, chatId ChatId) ([]*MemberRecord, error) - // PutMessage creates a new chat message - PutMessage(ctx context.Context, record *MessageRecord) error + // IsChatMember checks if a given member, identified by memberId, is part of a specific chat, identified by chatId. + IsChatMember(ctx context.Context, chatId ChatId, memberId MemberId) (bool, error) - // AdvancePointer advances a chat pointer for a chat member - AdvancePointer(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) - - // UpgradeIdentity upgrades a chat member's identity from an anonymous state - UpgradeIdentity(ctx context.Context, chatId ChatId, memberId MemberId, platform Platform, platformId string) error - - // SetMuteState updates the mute state for a chat member - SetMuteState(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error + // PutChatV2 stores or updates the metadata for a specific chat. + // + // ErrChatExists is returned if the chat with the same ID already exists. + PutChatV2(ctx context.Context, record *MetadataRecord) error - // SetSubscriptionState updates the subscription state for a chat member - SetSubscriptionState(ctx context.Context, chatId ChatId, memberId MemberId, isSubscribed bool) error -} + // PutChatMemberV2 stores or updates a member record for a specific chat. + // + // ErrMemberExists is returned if the member already exists. + // Updating should be done with specific DB calls. + PutChatMemberV2(ctx context.Context, record *MemberRecord) error -type PaymentStore interface { - // MarkFriendshipPaid marks a friendship as paid. + // PutChatMessageV2 stores or updates a message record in a specific chat. // - // The intentId is the intent that paid for the friendship. - MarkFriendshipPaid(ctx context.Context, payer, other *common.SolanaAccountId, intentId *common.IntentId) error + // ErrMessageExists is returned if the message already exists. + PutChatMessageV2(ctx context.Context, record *MessageRecord) error - // IsFriendshipPaid returns whether a payment has been made for a friendship. + // SetChatMuteStateV2 sets the mute state for a specific chat member, identified by chatId and memberId. // - // IsFriendshipPaid is reflexive, with only a single payment being required. - IsFriendshipPaid(ctx context.Context, user, other *common.SolanaAccountId) (bool, error) + // ErrMemberNotFound if the member does not exist. + SetChatMuteStateV2(ctx context.Context, chatId ChatId, memberId MemberId, isMuted bool) error - // MarkChatPaid marks a chat as paid. - MarkChatPaid(ctx context.Context, payer *common.SolanaAccountId, chat ChatId) error + // AdvanceChatPointerV2 advances a pointer for a chat member, identified by chatId and memberId. + // + // It returns whether the pointer was advanced. If no member exists, ErrMemberNotFound is returned. + AdvanceChatPointerV2(ctx context.Context, chatId ChatId, memberId MemberId, pointerType PointerType, pointer MessageId) (bool, error) - // IsChatPaid returns whether a member paid to be part of a chat. + // GetChatUnreadCountV2 calculates and returns the unread message count for a specific chat member, // - // This is only valid for non-two way chats. - IsChatPaid(ctx context.Context, chatId ChatId, member *common.SolanaAccountId) (bool, error) + // Existence checks are not performed. + GetChatUnreadCountV2(ctx context.Context, chatId ChatId, memberId MemberId, readPointer *MessageId) (uint32, error) } diff --git a/pkg/code/data/chat/v2/tests/tests.go b/pkg/code/data/chat/v2/tests/tests.go deleted file mode 100644 index 94c85a89..00000000 --- a/pkg/code/data/chat/v2/tests/tests.go +++ /dev/null @@ -1,14 +0,0 @@ -package tests - -import ( - "testing" - - chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" -) - -func RunTests(t *testing.T, s chat.Store, teardown func()) { - for _, tf := range []func(t *testing.T, s chat.Store){} { - tf(t, s) - teardown() - } -} diff --git a/pkg/code/data/internal.go b/pkg/code/data/internal.go index 342bc3c6..758e1b5c 100644 --- a/pkg/code/data/internal.go +++ b/pkg/code/data/internal.go @@ -396,20 +396,18 @@ type DatabaseData interface { // Chat V2 // -------------------------------------------------------------------------------- - GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) - GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) - GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) - GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) - GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) + GetChatMetadata(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.MetadataRecord, error) + GetChatMessageV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) + GetAllChatsForUserV2(ctx context.Context, user chat_v2.MemberId, opts ...query.Option) ([]chat_v2.ChatId, error) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) - GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) - PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error + GetChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) + IsChatMember(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (bool, error) + PutChatV2(ctx context.Context, record *chat_v2.MetadataRecord) error PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error - AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) - UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error - SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error + AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) + GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer *chat_v2.MessageId) (uint32, error) // Badge Count // -------------------------------------------------------------------------------- @@ -1467,55 +1465,41 @@ func (dp *DatabaseProvider) SetChatSubscriptionStateV1(ctx context.Context, chat // Chat V2 // -------------------------------------------------------------------------------- -func (dp *DatabaseProvider) GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error) { - return dp.chatv2.GetChatById(ctx, chatId) -} -func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error) { - return dp.chatv2.GetMemberById(ctx, chatId, memberId) +func (dp *DatabaseProvider) GetChatMetadata(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.MetadataRecord, error) { + return dp.chatv2.GetChatMetadata(ctx, chatId) } -func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { - return dp.chatv2.GetMessageById(ctx, chatId, messageId) +func (dp *DatabaseProvider) GetChatMessageV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) { + return dp.chatv2.GetChatMessageV2(ctx, chatId, messageId) } -func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { - return dp.chatv2.GetAllMembersByChatId(ctx, chatId) -} -func (dp *DatabaseProvider) GetPlatformUserChatMembershipV2(ctx context.Context, idByPlatform map[chat_v2.Platform]string, opts ...query.Option) ([]*chat_v2.MemberRecord, error) { - req, err := query.DefaultPaginationHandler(opts...) - if err != nil { - return nil, err - } - return dp.chatv2.GetAllMembersByPlatformIds(ctx, idByPlatform, req.Cursor, req.SortBy, req.Limit) +func (dp *DatabaseProvider) GetAllChatsForUserV2(ctx context.Context, user chat_v2.MemberId, opts ...query.Option) ([]chat_v2.ChatId, error) { + return dp.chatv2.GetAllChatsForUserV2(ctx, user, opts...) } func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) { - req, err := query.DefaultPaginationHandler(opts...) - if err != nil { - return nil, err - } - return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit) + return dp.chatv2.GetAllChatMessagesV2(ctx, chatId, opts...) } -func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) { - return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer) +func (dp *DatabaseProvider) GetChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) { + return dp.chatv2.GetChatMembersV2(ctx, chatId) } -func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error { - return dp.chatv2.PutChat(ctx, record) +func (dp *DatabaseProvider) IsChatMember(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (bool, error) { + return dp.chatv2.IsChatMember(ctx, chatId, memberId) +} +func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.MetadataRecord) error { + return dp.chatv2.PutChatV2(ctx, record) } func (dp *DatabaseProvider) PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error { - return dp.chatv2.PutMember(ctx, record) + return dp.chatv2.PutChatMemberV2(ctx, record) } func (dp *DatabaseProvider) PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error { - return dp.chatv2.PutMessage(ctx, record) -} -func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { - return dp.chatv2.AdvancePointer(ctx, chatId, memberId, pointerType, pointer) -} -func (dp *DatabaseProvider) UpgradeChatMemberIdentityV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, platform chat_v2.Platform, platformId string) error { - return dp.chatv2.UpgradeIdentity(ctx, chatId, memberId, platform, platformId) + return dp.chatv2.PutChatMessageV2(ctx, record) } func (dp *DatabaseProvider) SetChatMuteStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isMuted bool) error { - return dp.chatv2.SetMuteState(ctx, chatId, memberId, isMuted) + return dp.chatv2.SetChatMuteStateV2(ctx, chatId, memberId, isMuted) +} +func (dp *DatabaseProvider) AdvanceChatPointerV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, pointerType chat_v2.PointerType, pointer chat_v2.MessageId) (bool, error) { + return dp.chatv2.AdvanceChatPointerV2(ctx, chatId, memberId, pointerType, pointer) } -func (dp *DatabaseProvider) SetChatSubscriptionStateV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, isSubscribed bool) error { - return dp.chatv2.SetSubscriptionState(ctx, chatId, memberId, isSubscribed) +func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer *chat_v2.MessageId) (uint32, error) { + return dp.chatv2.GetChatUnreadCountV2(ctx, chatId, memberId, readPointer) } // Badge Count diff --git a/pkg/code/push/notifications.go b/pkg/code/push/notifications.go index 5bafd6c1..e7f5ba33 100644 --- a/pkg/code/push/notifications.go +++ b/pkg/code/push/notifications.go @@ -422,7 +422,7 @@ func SendChatMessagePushNotificationV2( chatId chat_v2.ChatId, chatTitle string, owner *common.Account, - chatMessage *chatv2pb.ChatMessage, + chatMessage *chatv2pb.Message, ) error { log := logrus.StandardLogger().WithFields(logrus.Fields{ "method": "SendChatMessagePushNotificationV2", @@ -512,15 +512,6 @@ func SendChatMessagePushNotificationV2( } case *chatv2pb.Content_NaclBox, *chatv2pb.Content_Text: contentToPush = content - case *chatv2pb.Content_ThankYou: - contentToPush = &chatv2pb.Content{ - Type: &chatv2pb.Content_Localized{ - Localized: &chatv2pb.LocalizedContent{ - // todo: localize this - KeyOrText: "🙏 They thanked you for their tip", - }, - }, - } } if contentToPush == nil { diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index cbdc2002..235b5265 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -1,11 +1,12 @@ package chat_v2 import ( + "bytes" "context" "database/sql" - "encoding/base64" "fmt" - "math" + "github.com/code-payments/code-server/pkg/code/data/account" + "github.com/code-payments/code-server/pkg/pointer" "sync" "time" @@ -23,7 +24,6 @@ import ( chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" auth_util "github.com/code-payments/code-server/pkg/code/auth" - chatv2 "github.com/code-payments/code-server/pkg/code/chat" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" @@ -34,7 +34,6 @@ import ( "github.com/code-payments/code-server/pkg/database/query" "github.com/code-payments/code-server/pkg/grpc/client" "github.com/code-payments/code-server/pkg/push" - timelock_token "github.com/code-payments/code-server/pkg/solana/timelock/v1" sync_util "github.com/code-payments/code-server/pkg/sync" ) @@ -87,8 +86,8 @@ func NewChatServer( return s } -// todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*chatpb.GetChatsResponse, error) { + // todo: This will require a lot of optimizations since we iterate and make several DB calls for each chat membership log := s.log.WithField("method", "GetChats") log = client.InjectLoggingMetadata(ctx, log) @@ -125,68 +124,36 @@ func (s *Server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch var cursor query.Cursor if req.Cursor != nil { cursor = req.Cursor.Value - } else { - cursor = query.ToCursor(0) - if direction == query.Descending { - cursor = query.ToCursor(math.MaxInt64 - 1) - } } - myIdentities, err := s.getAllIdentities(ctx, owner) + memberID, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("failure getting identities for owner account") + log.WithError(err).Warn("Failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - // todo: Use a better query that returns chat IDs. This will result in duplicate - // chat results if the user is in the chat multiple times across many identities. - platformUserMemberRecords, err := s.data.GetPlatformUserChatMembershipV2( + chats, err := s.data.GetAllChatsForUserV2( ctx, - myIdentities, + memberID, query.WithCursor(cursor), query.WithDirection(direction), query.WithLimit(limit), ) - if err == chat.ErrMemberNotFound { - return &chatpb.GetChatsResponse{ - Result: chatpb.GetChatsResponse_NOT_FOUND, - }, nil - } else if err != nil { - log.WithError(err).Warn("failure getting chat members for platform user") - return nil, status.Error(codes.Internal, "") - } - - log.WithField("chats", len(platformUserMemberRecords)).Info("Retrieved chatlist for user") - - var protoChats []*chatpb.ChatMetadata - for _, platformUserMemberRecord := range platformUserMemberRecords { - log := log.WithField("chat_id", platformUserMemberRecord.ChatId.String()) - - chatRecord, err := s.data.GetChatByIdV2(ctx, platformUserMemberRecord.ChatId) - if err != nil { - log.WithError(err).Warn("failure getting chat record") - return nil, status.Error(codes.Internal, "") - } - - memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatRecord.ChatId) - if err != nil { - log.WithError(err).Warn("failure getting chat members") - return nil, status.Error(codes.Internal, "") - } - protoChat, err := s.toProtoChat(ctx, chatRecord, memberRecords, myIdentities) + log.WithField("chats", len(chats)).Info("Retrieved chatlist for user") + metadata := make([]*chatpb.Metadata, 0, len(chats)) + for _, id := range chats { + md, err := s.getMetadata(ctx, memberID, id) if err != nil { - log.WithError(err).Warn("failure constructing proto chat message") - return nil, status.Error(codes.Internal, "") + return nil, nil } - protoChat.Cursor = &chatpb.Cursor{Value: query.ToCursor(uint64(platformUserMemberRecord.Id))} - protoChats = append(protoChats, protoChat) + metadata = append(metadata, md) } return &chatpb.GetChatsResponse{ Result: chatpb.GetChatsResponse_OK, - Chats: protoChats, + Chats: metadata, }, nil } @@ -201,19 +168,19 @@ func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid chat id") + log.WithError(err).Warn("failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("chat_id", chatId.String()) + log = log.WithField("member_id", memberId.String()) - memberId, err := chat.GetMemberIdFromProto(req.MemberId) + chatId, err := chat.GetChatIdFromProto(req.ChatId) if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("invalid chat id") return nil, status.Error(codes.Internal, "") } - log = log.WithField("member_id", memberId.String()) + log = log.WithField("chat_id", chatId.String()) signature := req.Signature req.Signature = nil @@ -221,23 +188,12 @@ func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest return nil, err } - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, - }, nil - default: - log.WithError(err).Warn("failure getting chat record") - return nil, status.Error(codes.Internal, "") - } - - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + isChatMember, err := s.data.IsChatMember(ctx, chatId, memberId) if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") + log.WithError(err).Warn("failed to check if chat member") return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { + } + if !isChatMember { return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_DENIED, }, nil @@ -273,20 +229,11 @@ func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest query.WithDirection(direction), query.WithLimit(limit), ) - if err == chat.ErrMessageNotFound { - return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, - }, nil - } else if err != nil { + if err != nil { log.WithError(err).Warn("failure getting chat messages") return nil, status.Error(codes.Internal, "") } - if len(protoChatMessages) == 0 { - return &chatpb.GetMessagesResponse{ - Result: chatpb.GetMessagesResponse_MESSAGE_NOT_FOUND, - }, nil - } return &chatpb.GetMessagesResponse{ Result: chatpb.GetMessagesResponse_OK, Messages: protoChatMessages, @@ -322,9 +269,9 @@ func (s *Server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } log = log.WithField("chat_id", chatId.String()) - memberId, err := chat.GetMemberIdFromProto(req.GetOpenStream().MemberId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("failed to derive messaging account") return status.Error(codes.Internal, "") } log = log.WithField("member_id", memberId.String()) @@ -335,25 +282,12 @@ func (s *Server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e return err } - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return streamer.Send(&chatpb.StreamChatEventsResponse{ - Type: &chatpb.StreamChatEventsResponse_Error{ - Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_CHAT_NOT_FOUND}, - }, - }) - default: - log.WithError(err).Warn("failure getting chat record") - return status.Error(codes.Internal, "") - } - - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") + log.WithError(err).Warn("failed to derive messaging account") return status.Error(codes.Internal, "") - } else if !ownsChatMember { + } + if !isMember { return streamer.Send(&chatpb.StreamChatEventsResponse{ Type: &chatpb.StreamChatEventsResponse_Error{ Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_DENIED}, @@ -451,84 +385,6 @@ func (s *Server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e } } -func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { - log := s.log.WithFields(logrus.Fields{ - "method": "flushMessages", - "chat_id": chatId.String(), - "owner_account": owner.PublicKey().ToBase58(), - }) - - protoChatMessages, err := s.getProtoChatMessages( - ctx, - chatId, - owner, - query.WithCursor(query.EmptyCursor), - query.WithDirection(query.Descending), - query.WithLimit(flushMessageCount), - ) - if err == chat.ErrMessageNotFound { - return - } else if err != nil { - log.WithError(err).Warn("failure getting chat messages") - return - } - - for _, protoChatMessage := range protoChatMessages { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: protoChatMessage, - }, - } - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - return - } - } -} - -func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { - log := s.log.WithFields(logrus.Fields{ - "method": "flushPointers", - "chat_id": chatId.String(), - }) - - memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatId) - if err == chat.ErrMemberNotFound { - return - } else if err != nil { - log.WithError(err).Warn("failure getting chat members") - return - } - - for _, memberRecord := range memberRecords { - for _, optionalPointer := range []struct { - kind chat.PointerType - value *chat.MessageId - }{ - {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, - {chat.PointerTypeRead, memberRecord.ReadPointer}, - } { - if optionalPointer.value == nil { - continue - } - - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Pointer{ - Pointer: &chatpb.Pointer{ - Type: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberRecord.MemberId.ToProto(), - }, - }, - } - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - return - } - } - } -} - func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { log := s.log.WithField("method", "StartChat") log = client.InjectLoggingMetadata(ctx, log) @@ -540,124 +396,154 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive member id") + return nil, status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId) + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - // todo: Maybe expand this in the future. - if req.Self.Platform != chatpb.Platform_TWITTER { - log.Info("cannot start chat without specifying username") - return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_INVALID_PARAMETER}, nil - } - - selfVerified, err := s.ownsTwitterUsername(ctx, owner, req.Self.Username) - if err != nil { - log.WithError(err).Warn("failed to verify creators twitter") + creator, err := s.data.GetTwitterUserByTipAddress(ctx, memberId.String()) + if errors.Is(err, twitter.ErrUserNotFound) { + log.WithField("memberId", memberId).Info("User has no twitter account") + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_MISSING_IDENTITY}, nil + } else if err != nil { + log.WithError(err).Warn("failed to get twitter user") return nil, status.Error(codes.Internal, "") } - if !selfVerified { - return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil - } switch typed := req.Parameters.(type) { case *chatpb.StartChatRequest_TwoWayChat: - chatId := chat.GetChatId(owner.PublicKey().ToBase58(), base58.Encode(typed.TwoWayChat.OtherUser.Value), true) + chatId := chat.GetTwoWayChatId(memberId, typed.TwoWayChat.OtherUser.Value) + + metadata, err := s.getMetadata(ctx, memberId, chatId) + if err == nil { + return &chatpb.StartChatResponse{ + Chat: metadata, + }, nil + + } else if err != nil && !errors.Is(err, chat.ErrChatNotFound) { + log.WithError(err).Warn("failed to get chat metadata") + return nil, status.Error(codes.Internal, "") + } if typed.TwoWayChat.IntentId == nil { - /* - isFriends, err := s.data.IsFriendshipPaid(ctx, owner, typed.TwoWayChat.OtherUser) - if err != nil { - log.WithError(err).Warn("failure checking two way chat") - return nil, status.Error(codes.Internal, "") - } + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_INVALID_PARAMETER}, nil + } - if !isFriends { - return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil - } - */ - _, err = s.data.GetChatByIdV2(ctx, chatId) - if errors.Is(err, chat.ErrChatNotFound) { - return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil - } else if err != nil { - log.WithError(err).Warn("failure checking two way chat") - return nil, status.Error(codes.Internal, "") - } - } else { - intentId := base58.Encode(typed.TwoWayChat.IntentId.Value) - log = log.WithField("intent", intentId) - - intentRecord, err := s.data.GetIntent(ctx, intentId) - if errors.Is(err, intent.ErrIntentNotFound) { - log.WithError(err).Info("Intent not found") - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_INVALID_PARAMETER, - Chat: nil, - }, nil - } else if err != nil { - log.WithError(err).Warn("failure getting intent record") - return nil, status.Error(codes.Internal, "") - } + intentId := base58.Encode(typed.TwoWayChat.IntentId.Value) + log = log.WithField("intent", intentId) - if intentRecord.SendPrivatePaymentMetadata == nil { - return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil - } + intentRecord, err := s.data.GetIntent(ctx, intentId) + if errors.Is(err, intent.ErrIntentNotFound) { + log.WithError(err).Info("Intent not found") + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_INVALID_PARAMETER, + Chat: nil, + }, nil + } else if err != nil { + log.WithError(err).Warn("failure getting intent record") + return nil, status.Error(codes.Internal, "") + } - // TODO: Further verification + Enforcement - if !intentRecord.SendPrivatePaymentMetadata.IsChat { - log.Warn("intent is not for chat") - } + switch intentRecord.State { + case intent.StatePending: + return &chatpb.StartChatResponse{ + Result: chatpb.StartChatResponse_PENDING, + }, nil + case intent.StateConfirmed: + default: + log.WithField("state", intentRecord.State).Info("PayToChat intent did not succeed") + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } - expectedChatId := base64.StdEncoding.EncodeToString(chatId[:]) - if intentRecord.SendPrivatePaymentMetadata.ChatId != expectedChatId { - log.WithField("expected", expectedChatId).WithField("actual", intentRecord.SendPrivatePaymentMetadata.ChatId).Warn("chat id mismatch") - } + if intentRecord.SendPrivatePaymentMetadata == nil { + log.Info("intent missing private payment meta") + //return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + + if !intentRecord.SendPrivatePaymentMetadata.IsChat { + log.Info("intent is not for chat") + //return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_DENIED}, nil + } + + expectedChatId := base58.Encode(chatId[:]) + if intentRecord.SendPrivatePaymentMetadata.ChatId != expectedChatId { + log.WithField("expected", expectedChatId).WithField("actual", intentRecord.SendPrivatePaymentMetadata.ChatId).Warn("chat id mismatch") + } + + otherMessagingAddress := base58.Encode(typed.TwoWayChat.OtherUser.Value) + + otherTwitter, err := s.data.GetTwitterUserByTipAddress(ctx, otherMessagingAddress) + if errors.Is(err, twitter.ErrUserNotFound) { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_USER_NOT_FOUND}, nil + } else if err != nil { + log.WithError(err).Warn("failure checking twitter user") + return nil, status.Error(codes.Internal, "") + } + + otherAccount, err := s.data.GetAccountInfoByTokenAddress(ctx, otherMessagingAddress) + if errors.Is(err, account.ErrAccountInfoNotFound) { + return &chatpb.StartChatResponse{Result: chatpb.StartChatResponse_USER_NOT_FOUND}, nil + } else if err != nil { + log.WithError(err).Warn("failure checking account info") + return nil, status.Error(codes.Internal, "") } // At this point, we assume the relationship is valid, and can proceed to recover or create // the chat record. creationTs := time.Now() - chatRecord := &chat.ChatRecord{ - ChatId: chatId, - ChatType: chat.ChatTypeTwoWay, - IsVerified: true, - CreatedAt: creationTs, + chatRecord := &chat.MetadataRecord{ + ChatId: chatId, + ChatType: chat.ChatTypeTwoWay, + CreatedAt: creationTs, + ChatTitle: nil, } + memberRecords := []*chat.MemberRecord{ { ChatId: chatId, - MemberId: chat.GenerateMemberId(), + MemberId: memberId.String(), + Owner: owner.PublicKey().ToBase58(), - Platform: chat.PlatformTwitter, - PlatformId: req.Self.Username, - OwnerAccount: owner.PublicKey().ToBase58(), + Platform: chat.PlatformTwitter, + PlatformId: creator.Username, JoinedAt: creationTs, }, { ChatId: chatId, - MemberId: chat.GenerateMemberId(), + MemberId: otherMessagingAddress, + Owner: otherAccount.OwnerAccount, - Platform: chat.PlatformTwitter, - PlatformId: typed.TwoWayChat.Identity.Username, - OwnerAccount: base58.Encode(typed.TwoWayChat.OtherUser.Value), + Platform: chat.PlatformTwitter, + PlatformId: otherTwitter.Username, JoinedAt: time.Now(), }, } + // Note: this should almost _always_ succeed in the happy path, since we check + // for existence earlier! + // + // The only time we have to rollback and query is on race of creation. err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { - existingChatRecord, err := s.data.GetChatByIdV2(ctx, chatId) + existingChatRecord, err := s.data.GetChatMetadata(ctx, chatId) if err != nil && !errors.Is(err, chat.ErrChatNotFound) { return fmt.Errorf("failed to check existing chat: %w", err) } if existingChatRecord != nil { chatRecord = existingChatRecord - memberRecords, err = s.data.GetAllChatMembersV2(ctx, chatId) + memberRecords, err = s.data.GetChatMembersV2(ctx, chatId) if err != nil { - return fmt.Errorf("failed to get members of existing chat: %w", err) + return fmt.Errorf("failed to check existing chat members: %w", err) } return nil @@ -679,23 +565,15 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* return nil, status.Error(codes.Internal, "") } - protoChat, err := s.toProtoChat( - ctx, - chatRecord, - memberRecords, - map[chat.Platform]string{ - chat.PlatformCode: owner.PublicKey().ToBase58(), - chat.PlatformTwitter: req.Self.Username, - }, - ) + md, err := s.populateMetadata(ctx, chatRecord, memberRecords, memberId) if err != nil { - log.WithError(err).Warn("failure constructing proto chat message") + log.WithError(err).Warn("failure populating metadata") return nil, status.Error(codes.Internal, "") } return &chatpb.StartChatResponse{ Result: chatpb.StartChatResponse_OK, - Chat: protoChat, + Chat: md, }, nil default: @@ -714,22 +592,22 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid chat id") + log.WithError(err).Warn("failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("chat_id", chatId.String()) + log = log.WithField("member_id", memberId) - memberId, err := chat.GetMemberIdFromProto(req.MemberId) + chatId, err := chat.GetChatIdFromProto(req.ChatId) if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("invalid chat id") return nil, status.Error(codes.Internal, "") } - log = log.WithField("member_id", memberId.String()) + log = log.WithField("chat_id", chatId.String()) switch req.Content[0].Type.(type) { - case *chatpb.Content_Text, *chatpb.Content_ThankYou: + case *chatpb.Content_Text: default: return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_INVALID_CONTENT_TYPE, @@ -742,33 +620,22 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest return nil, err } - chatRecord, err := s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: + metadata, err := s.data.GetChatMetadata(ctx, chatId) + if errors.Is(err, chat.ErrChatNotFound) { return &chatpb.SendMessageResponse{ - Result: chatpb.SendMessageResponse_CHAT_NOT_FOUND, + Result: chatpb.SendMessageResponse_DENIED, }, nil - default: + } else if err != nil { log.WithError(err).Warn("failure getting chat record") return nil, status.Error(codes.Internal, "") } - var chatTitle string - switch chatRecord.ChatType { - case chat.ChatTypeTwoWay: - chatTitle = chatv2.TwoWayChatName - default: - return &chatpb.SendMessageResponse{ - Result: chatpb.SendMessageResponse_INVALID_CHAT_TYPE, - }, nil - } - - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") + log.WithError(err).Warn("failure checking member record") return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { + } + if !isMember { return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_DENIED, }, nil @@ -787,7 +654,7 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest } s.onPersistChatMessage(log, chatId, chatMessage) - s.sendPushNotifications(chatId, chatTitle, memberId, chatMessage) + s.sendPushNotifications(chatId, pointer.StringOrEmpty(metadata.ChatTitle), memberId, chatMessage) return &chatpb.SendMessageResponse{ Result: chatpb.SendMessageResponse_OK, @@ -795,100 +662,35 @@ func (s *Server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest }, nil } -// TODO(api): This likely needs an RPC that can be called from any other Server. -func (s *Server) NotifyMessage(_ context.Context, chatID chat.ChatId, message *chatpb.ChatMessage) { - log := s.log.WithFields(logrus.Fields{ - "chat_id": chatID.String(), - "messge_id": message.MessageId.String(), - }) - - s.onPersistChatMessage(log, chatID, message) -} - -// todo: This belongs in the common chat utility, which currently only operates on v1 chats -func (s *Server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.ChatMessage) error { - if err := protoChatMessage.Validate(); err != nil { - return errors.Wrap(err, "proto chat message failed validation") - } - - messageId, err := chat.GetMessageIdFromProto(protoChatMessage.MessageId) - if err != nil { - return errors.Wrap(err, "invalid message id") - } - - var senderId *chat.MemberId - if protoChatMessage.SenderId != nil { - convertedSenderId, err := chat.GetMemberIdFromProto(protoChatMessage.SenderId) - if err != nil { - return errors.Wrap(err, "invalid member id") - } - senderId = &convertedSenderId - } - - // Clear out extracted metadata as a space optimization - cloned := proto.Clone(protoChatMessage).(*chatpb.ChatMessage) - cloned.MessageId = nil - cloned.SenderId = nil - cloned.Ts = nil - cloned.Cursor = nil - - marshalled, err := proto.Marshal(cloned) - if err != nil { - return errors.Wrap(err, "error marshalling proto chat message") - } - - // todo: Doesn't incoroporate reference. We might want to promote the proto a level above the content. - messageRecord := &chat.MessageRecord{ - ChatId: chatId, - MessageId: messageId, - - Sender: senderId, - - Data: marshalled, - - IsSilent: false, - } - - err = s.data.PutChatMessageV2(ctx, messageRecord) - if err != nil { - return errors.Wrap(err, "error persiting chat message") - } - return nil -} - func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerRequest) (*chatpb.AdvancePointerResponse, error) { log := s.log.WithField("method", "AdvancePointer") log = client.InjectLoggingMetadata(ctx, log) - owner, err := common.NewAccountFromProto(req.Owner) + chatId, err := chat.GetChatIdFromProto(req.ChatId) if err != nil { - log.WithError(err).Warn("invalid owner account") + log.WithError(err).Warn("invalid chat id") return nil, status.Error(codes.Internal, "") } - log = log.WithField("owner_account", owner.PublicKey().ToBase58()) + log = log.WithField("chat_id", chatId.String()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + owner, err := common.NewAccountFromProto(req.Owner) if err != nil { - log.WithError(err).Warn("invalid chat id") + log.WithError(err).Warn("invalid owner account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("chat_id", chatId.String()) + log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - memberId, err := chat.GetMemberIdFromProto(req.Pointer.MemberId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("member_id", memberId.String()) + log = log.WithField("member_id", memberId) pointerType := chat.GetPointerTypeFromProto(req.Pointer.Type) log = log.WithField("pointer_type", pointerType.String()) - switch pointerType { - case chat.PointerTypeDelivered, chat.PointerTypeRead: - default: - return &chatpb.AdvancePointerResponse{ - Result: chatpb.AdvancePointerResponse_INVALID_POINTER_TYPE, - }, nil + if pointerType <= chat.PointerTypeUnknown || pointerType > chat.PointerTypeRead { + return nil, status.Error(codes.InvalidArgument, "invalid pointer type") } pointerValue, err := chat.GetMessageIdFromProto(req.Pointer.Value) @@ -898,42 +700,30 @@ func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } log = log.WithField("pointer_value", pointerValue.String()) + // Force override whatever the user thought it should be. + req.Pointer.MemberId = memberId.ToProto() + signature := req.Signature req.Signature = nil if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { return nil, err } - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return &chatpb.AdvancePointerResponse{ - Result: chatpb.AdvancePointerResponse_CHAT_NOT_FOUND, - }, nil - default: - log.WithError(err).Warn("failure getting chat record") + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") return nil, status.Error(codes.Internal, "") } + if !isMember { + return &chatpb.AdvancePointerResponse{Result: chatpb.AdvancePointerResponse_DENIED}, nil + } - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) - if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") - return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { - return &chatpb.AdvancePointerResponse{ - Result: chatpb.AdvancePointerResponse_DENIED, - }, nil - } - - _, err = s.data.GetChatMessageByIdV2(ctx, chatId, pointerValue) - switch err { - case nil: - case chat.ErrMessageNotFound: + _, err = s.data.GetChatMessageV2(ctx, chatId, pointerValue) + if errors.Is(err, chat.ErrChatNotFound) { return &chatpb.AdvancePointerResponse{ Result: chatpb.AdvancePointerResponse_MESSAGE_NOT_FOUND, }, nil - default: + } else if err != nil { log.WithError(err).Warn("failure getting chat message record") return nil, status.Error(codes.Internal, "") } @@ -960,8 +750,8 @@ func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR }, nil } -func (s *Server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityRequest) (*chatpb.RevealIdentityResponse, error) { - log := s.log.WithField("method", "RevealIdentity") +func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { + log := s.log.WithField("method", "SetMuteState") log = client.InjectLoggingMetadata(ctx, log) owner, err := common.NewAccountFromProto(req.Owner) @@ -971,19 +761,19 @@ func (s *Server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid chat id") + log.WithError(err).Warn("failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("chat_id", chatId.String()) + log = log.WithField("member_id", memberId.String()) - memberId, err := chat.GetMemberIdFromProto(req.MemberId) + chatId, err := chat.GetChatIdFromProto(req.ChatId) if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("invalid chat id") return nil, status.Error(codes.Internal, "") } - log = log.WithField("member_id", memberId.String()) + log = log.WithField("chat_id", chatId.String()) signature := req.Signature req.Signature = nil @@ -991,131 +781,30 @@ func (s *Server) RevealIdentity(ctx context.Context, req *chatpb.RevealIdentityR return nil, err } - platform := chat.GetPlatformFromProto(req.Identity.Platform) - - log = log.WithFields(logrus.Fields{ - "platform": platform.String(), - "username": req.Identity.Username, - }) - - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_CHAT_NOT_FOUND, - }, nil - default: - log.WithError(err).Warn("failure getting chat record") - return nil, status.Error(codes.Internal, "") - } - - memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) - switch err { - case nil: - case chat.ErrMemberNotFound: - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_DENIED, - }, nil - default: - log.WithError(err).Warn("failure getting member record") - return nil, status.Error(codes.Internal, "") - } - - ownsChatMember, err := s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") + log.WithError(err).Warn("failure checking member record") return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_DENIED, - }, nil } - - switch platform { - case chat.PlatformTwitter: - ownsUsername, err := s.ownsTwitterUsername(ctx, owner, req.Identity.Username) - if err != nil { - log.WithError(err).Warn("failure determing twitter username ownership") - return nil, status.Error(codes.Internal, "") - } else if !ownsUsername { - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_DENIED, - }, nil - } - default: - return nil, status.Error(codes.InvalidArgument, "RevealIdentityRequest.Identity.Platform must be TWITTER") + if !isMember { + return &chatpb.SetMuteStateResponse{Result: chatpb.SetMuteStateResponse_DENIED}, nil } - // Idempotent RPC call using the same platform and username - if memberRecord.Platform == platform && memberRecord.PlatformId == req.Identity.Username { - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_OK, - }, nil - } - - // Identity was already revealed, and it isn't the specified platform and username - if memberRecord.Platform != chat.PlatformCode { - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, - }, nil - } - - chatLock := s.chatLocks.Get(chatId[:]) - chatLock.Lock() - defer chatLock.Unlock() - - chatMessage := newProtoChatMessage( - memberId, - &chatpb.Content{ - Type: &chatpb.Content_IdentityRevealed{ - IdentityRevealed: &chatpb.IdentityRevealedContent{ - MemberId: req.MemberId, - Identity: req.Identity, - }, - }, - }, - ) - - err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { - err = s.data.UpgradeChatMemberIdentityV2(ctx, chatId, memberId, platform, req.Identity.Username) - switch err { - case nil: - case chat.ErrMemberIdentityAlreadyUpgraded: - return err - default: - return errors.Wrap(err, "error updating chat member identity") - } - - err := s.persistChatMessage(ctx, chatId, chatMessage) - if err != nil { - return errors.Wrap(err, "error persisting chat message") - } - return nil - }) - - if err == nil { - s.onPersistChatMessage(log, chatId, chatMessage) - } + // todo: Use chat record to determine if muting is allowed - switch err { - case nil: - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_OK, - Message: chatMessage, - }, nil - case chat.ErrMemberIdentityAlreadyUpgraded: - return &chatpb.RevealIdentityResponse{ - Result: chatpb.RevealIdentityResponse_DIFFERENT_IDENTITY_REVEALED, - }, nil - default: - log.WithError(err).Warn("failure upgrading chat member identity") + err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) + if err != nil { + log.WithError(err).Warn("failure setting mute state") return nil, status.Error(codes.Internal, "") } + + return &chatpb.SetMuteStateResponse{ + Result: chatpb.SetMuteStateResponse_OK, + }, nil } -func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateRequest) (*chatpb.SetMuteStateResponse, error) { - log := s.log.WithField("method", "SetMuteState") +func (s *Server) NotifyIsTyping(ctx context.Context, req *chatpb.NotifyIsTypingRequest) (*chatpb.NotifyIsTypingResponse, error) { + log := s.log.WithField("method", "NotifyIsTyping") log = client.InjectLoggingMetadata(ctx, log) owner, err := common.NewAccountFromProto(req.Owner) @@ -1125,19 +814,19 @@ func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque } log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + memberId, err := owner.ToChatMemberId() if err != nil { - log.WithError(err).Warn("invalid chat id") + log.WithError(err).Warn("failed to derive messaging account") return nil, status.Error(codes.Internal, "") } - log = log.WithField("chat_id", chatId.String()) + log = log.WithField("member_id", memberId) - memberId, err := chat.GetMemberIdFromProto(req.MemberId) + chatId, err := chat.GetChatIdFromProto(req.ChatId) if err != nil { - log.WithError(err).Warn("invalid member id") + log.WithError(err).Warn("invalid chat id") return nil, status.Error(codes.Internal, "") } - log = log.WithField("member_id", memberId.String()) + log = log.WithField("chat_id", chatId.String()) signature := req.Signature req.Signature = nil @@ -1145,160 +834,165 @@ func (s *Server) SetMuteState(ctx context.Context, req *chatpb.SetMuteStateReque return nil, err } - // todo: Use chat record to determine if muting is allowed - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return &chatpb.SetMuteStateResponse{ - Result: chatpb.SetMuteStateResponse_CHAT_NOT_FOUND, - }, nil - default: - log.WithError(err).Warn("failure getting chat record") + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failure checking member record") return nil, status.Error(codes.Internal, "") } + if !isMember { + return &chatpb.NotifyIsTypingResponse{Result: chatpb.NotifyIsTypingResponse_DENIED}, nil + } - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) - if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") - return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { - return &chatpb.SetMuteStateResponse{ - Result: chatpb.SetMuteStateResponse_DENIED, - }, nil + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_IsTyping{ + IsTyping: &chatpb.IsTyping{ + MemberId: memberId.ToProto(), + IsTyping: req.IsTyping, + }, + }, } - err = s.data.SetChatMuteStateV2(ctx, chatId, memberId, req.IsMuted) - if err != nil { - log.WithError(err).Warn("failure setting mute state") - return nil, status.Error(codes.Internal, "") + if err := s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failure notifying chat event") } - return &chatpb.SetMuteStateResponse{ - Result: chatpb.SetMuteStateResponse_OK, - }, nil + return &chatpb.NotifyIsTypingResponse{}, nil } -func (s *Server) SetSubscriptionState(ctx context.Context, req *chatpb.SetSubscriptionStateRequest) (*chatpb.SetSubscriptionStateResponse, error) { - log := s.log.WithField("method", "SetSubscriptionState") - log = client.InjectLoggingMetadata(ctx, log) - - owner, err := common.NewAccountFromProto(req.Owner) +// todo: needs to have a 'fill' version +func (s *Server) getMetadata(ctx context.Context, asMember chat.MemberId, id chat.ChatId) (*chatpb.Metadata, error) { + mdRecord, err := s.data.GetChatMetadata(ctx, id) if err != nil { - log.WithError(err).Warn("invalid owner account") - return nil, status.Error(codes.Internal, "") + return nil, fmt.Errorf("failed to lookup metadata: %w", err) } - log = log.WithField("owner_account", owner.PublicKey().ToBase58()) - chatId, err := chat.GetChatIdFromProto(req.ChatId) + members, err := s.data.GetChatMembersV2(ctx, id) if err != nil { - log.WithError(err).Warn("invalid chat id") - return nil, status.Error(codes.Internal, "") + return nil, fmt.Errorf("failed to get members: %w", err) } - log = log.WithField("chat_id", chatId.String()) - memberId, err := chat.GetMemberIdFromProto(req.MemberId) - if err != nil { - log.WithError(err).Warn("invalid member id") - return nil, status.Error(codes.Internal, "") - } - log = log.WithField("member_id", memberId.String()) + return s.populateMetadata(ctx, mdRecord, members, asMember) +} - signature := req.Signature - req.Signature = nil - if err := s.auth.Authenticate(ctx, owner, req, signature); err != nil { - return nil, err +func (s *Server) populateMetadata(ctx context.Context, mdRecord *chat.MetadataRecord, members []*chat.MemberRecord, asMember chat.MemberId) (*chatpb.Metadata, error) { + md := &chatpb.Metadata{ + ChatId: mdRecord.ChatId.ToProto(), + Type: mdRecord.ChatType.ToProto(), + Cursor: &chatpb.Cursor{Value: mdRecord.ChatId[:]}, + IsMuted: false, + Muteable: false, + NumUnread: 0, } - // todo: Use chat record to determine if unsubscribing is allowed - _, err = s.data.GetChatByIdV2(ctx, chatId) - switch err { - case nil: - case chat.ErrChatNotFound: - return &chatpb.SetSubscriptionStateResponse{ - Result: chatpb.SetSubscriptionStateResponse_CHAT_NOT_FOUND, - }, nil - default: - log.WithError(err).Warn("failure getting chat record") - return nil, status.Error(codes.Internal, "") + if mdRecord.ChatTitle != nil { + md.Title = *mdRecord.ChatTitle } - ownsChatMember, err := s.ownsChatMemberWithoutRecord(ctx, chatId, memberId, owner) - if err != nil { - log.WithError(err).Warn("failure determing chat member ownership") - return nil, status.Error(codes.Internal, "") - } else if !ownsChatMember { - return &chatpb.SetSubscriptionStateResponse{ - Result: chatpb.SetSubscriptionStateResponse_DENIED, - }, nil - } + for _, m := range members { + memberId, err := chat.GetMemberIdFromString(m.MemberId) + if err != nil { + return nil, fmt.Errorf("invalid member id %q: %w", m.MemberId, err) + } - err = s.data.SetChatSubscriptionStateV2(ctx, chatId, memberId, req.IsSubscribed) - if err != nil { - log.WithError(err).Warn("failure setting mute state") - return nil, status.Error(codes.Internal, "") + member := &chatpb.Member{ + MemberId: memberId.ToProto(), + } + md.Members = append(md.Members, member) + + twitterUser, err := s.data.GetTwitterUserByTipAddress(ctx, m.MemberId) + if errors.Is(err, twitter.ErrUserNotFound) { + s.log.WithField("member", m.MemberId).Info("Twitter user not found for existing user") + } else if err != nil { + // TODO: If client have caching, we could just not do this... + return nil, fmt.Errorf("failed to get twitter user: %w", err) + } else { + member.Identity = &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: twitterUser.Username, + DisplayName: twitterUser.Name, + ProfilePicUrl: twitterUser.ProfilePicUrl, + } + } + + if m.DeliveryPointer != nil { + member.Pointers = append(member.Pointers, &chatpb.Pointer{ + Type: chatpb.PointerType_DELIVERED, + Value: m.DeliveryPointer.ToProto(), + MemberId: memberId.ToProto(), + }) + } + if m.ReadPointer != nil { + member.Pointers = append(member.Pointers, &chatpb.Pointer{ + Type: chatpb.PointerType_READ, + Value: m.ReadPointer.ToProto(), + MemberId: memberId.ToProto(), + }) + } + + md.IsMuted = m.IsMuted + + // If the member is not the requestor, then we can skip further processing + if !bytes.Equal(asMember, memberId) { + continue + } + + // TODO: Do we actually want to compute this feature? It's maybe non-trivial. + // Maybe should have a safety valve at minimum. + md.NumUnread, err = s.data.GetChatUnreadCountV2(ctx, mdRecord.ChatId, memberId, m.ReadPointer) + if err != nil { + return nil, fmt.Errorf("failed to get unread count: %w", err) + } } - return &chatpb.SetSubscriptionStateResponse{ - Result: chatpb.SetSubscriptionStateResponse_OK, - }, nil + return md, nil } -func (s *Server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, memberRecords []*chat.MemberRecord, myIdentitiesByPlatform map[chat.Platform]string) (*chatpb.ChatMetadata, error) { - protoChat := &chatpb.ChatMetadata{ - ChatId: chatRecord.ChatId.ToProto(), - Type: chatRecord.ChatType.ToProto(), - Cursor: &chatpb.Cursor{Value: query.ToCursor(uint64(chatRecord.Id))}, - } +func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushMessages", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) - switch chatRecord.ChatType { - case chat.ChatTypeTwoWay: - protoChat.Title = "Tip Chat" // todo: proper title with localization + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(query.EmptyCursor), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } - protoChat.CanMute = true - protoChat.CanUnsubscribe = true - case chat.ChatTypeNotification: - if chatRecord.ChatTitle == nil { - // TODO: we shouldn't fail the whole RPC - return nil, fmt.Errorf("invalid notification chat: missing title") + for _, protoChatMessage := range protoChatMessages { + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Message{ + Message: protoChatMessage, + }, } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} - // TODO: Localization - protoChat.Title = *chatRecord.ChatTitle +func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushPointers", + "chat_id": chatId.String(), + }) - default: - return nil, errors.Errorf("unsupported chat type: %s", chatRecord.ChatType.String()) + memberRecords, err := s.data.GetChatMembersV2(ctx, chatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return } for _, memberRecord := range memberRecords { - var isSelf bool - var identity *chatpb.ChatMemberIdentity - switch memberRecord.Platform { - case chat.PlatformCode: - myPublicKey, ok := myIdentitiesByPlatform[chat.PlatformCode] - isSelf = ok && myPublicKey == memberRecord.PlatformId - case chat.PlatformTwitter: - myTwitterUsername, ok := myIdentitiesByPlatform[chat.PlatformTwitter] - isSelf = ok && myTwitterUsername == memberRecord.PlatformId - - profilePicUrl := "" - user, err := s.data.GetTwitterUserByUsername(ctx, memberRecord.PlatformId) - if err != nil { - s.log.WithError(err).WithField("method", "toProtoChat").Warn("Failed to get twitter user for member record") - } else { - profilePicUrl = user.ProfilePicUrl - } - - identity = &chatpb.ChatMemberIdentity{ - Platform: memberRecord.Platform.ToProto(), - Username: memberRecord.PlatformId, - ProfilePicUrl: profilePicUrl, - } - default: - return nil, errors.Errorf("unsupported platform type: %s", memberRecord.Platform.String()) - } - - var pointers []*chatpb.Pointer for _, optionalPointer := range []struct { kind chat.PointerType value *chat.MessageId @@ -1310,57 +1004,40 @@ func (s *Server) toProtoChat(ctx context.Context, chatRecord *chat.ChatRecord, m continue } - pointers = append(pointers, &chatpb.Pointer{ - Type: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberRecord.MemberId.ToProto(), - }) - } + memberId, err := chat.GetMemberIdFromString(memberRecord.MemberId) + if err != nil { + log.WithError(err).Warnf("failure getting memberId from %s", memberRecord.MemberId) + return + } - protoMember := &chatpb.ChatMember{ - MemberId: memberRecord.MemberId.ToProto(), - IsSelf: isSelf, - Identity: identity, - Pointers: pointers, - } - if protoMember.IsSelf { - protoMember.IsMuted = memberRecord.IsMuted - protoMember.IsSubscribed = !memberRecord.IsUnsubscribed - - if !memberRecord.IsUnsubscribed { - readPointer := chat.GenerateMessageIdAtTime(time.Unix(0, 0)) - if memberRecord.ReadPointer != nil { - readPointer = *memberRecord.ReadPointer - } - unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, memberRecord.MemberId, readPointer) - if err != nil { - return nil, errors.Wrap(err, "error calculating unread count") - } - protoMember.NumUnread = unreadCount + event := &chatpb.ChatStreamEvent{ + Type: &chatpb.ChatStreamEvent_Pointer{ + Pointer: &chatpb.Pointer{ + Type: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberId.ToProto(), + }, + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return } } - - protoChat.Members = append(protoChat.Members, protoMember) } - - return protoChat, nil } -func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.ChatMessage, error) { - messageRecords, err := s.data.GetAllChatMessagesV2( - ctx, - chatId, - queryOptions..., - ) - if err == chat.ErrMessageNotFound { - return nil, err +func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.Message, error) { + messageRecords, err := s.data.GetAllChatMessagesV2(ctx, chatId, queryOptions...) + if err != nil { + return nil, fmt.Errorf("failure getting chat messages: %w", err) } var userLocale *language.Tag // Loaded lazily when required - var res []*chatpb.ChatMessage + var res []*chatpb.Message for _, messageRecord := range messageRecords { - var protoChatMessage chatpb.ChatMessage - err = proto.Unmarshal(messageRecord.Data, &protoChatMessage) + var protoChatMessage chatpb.Message + err = proto.Unmarshal(messageRecord.Payload, &protoChatMessage) if err != nil { return nil, errors.Wrap(err, "error unmarshalling proto chat message") } @@ -1402,7 +1079,54 @@ func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, o return res, nil } -func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.ChatMessage) { +// todo: This belongs in the common chat utility, which currently only operates on v1 chats +func (s *Server) persistChatMessage(ctx context.Context, chatId chat.ChatId, protoChatMessage *chatpb.Message) error { + if err := protoChatMessage.Validate(); err != nil { + return errors.Wrap(err, "proto chat message failed validation") + } + + messageId, err := chat.GetMessageIdFromProto(protoChatMessage.MessageId) + if err != nil { + return errors.Wrap(err, "invalid message id") + } + + var senderId *chat.MemberId + if protoChatMessage.SenderId != nil { + convertedSenderId, err := chat.GetMemberIdFromProto(protoChatMessage.SenderId) + if err != nil { + return errors.Wrap(err, "invalid member id") + } + senderId = &convertedSenderId + } + + // Clear out extracted metadata as a space optimization + cloned := proto.Clone(protoChatMessage).(*chatpb.Message) + cloned.MessageId = nil + cloned.SenderId = nil + cloned.Ts = nil + cloned.Cursor = nil + + marshalled, err := proto.Marshal(cloned) + if err != nil { + return errors.Wrap(err, "error marshalling proto chat message") + } + + messageRecord := &chat.MessageRecord{ + ChatId: chatId, + MessageId: messageId, + Sender: senderId, + Payload: marshalled, + IsSilent: false, + } + + err = s.data.PutChatMessageV2(ctx, messageRecord) + if err != nil { + return errors.Wrap(err, "error persisting chat message") + } + return nil +} + +func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.Message) { event := &chatpb.ChatStreamEvent{ Type: &chatpb.ChatStreamEvent_Message{ Message: chatMessage, @@ -1414,7 +1138,7 @@ func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, cha } } -func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sender chat.MemberId, message *chatpb.ChatMessage) { +func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sender chat.MemberId, message *chatpb.Message) { log := s.log.WithFields(logrus.Fields{ "method": "sendPushNotifications", "sender": sender.String(), @@ -1422,7 +1146,7 @@ func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sen }) // TODO: Callers might already have this loaded. - members, err := s.data.GetAllChatMembersV2(context.Background(), chatId) + members, err := s.data.GetChatMembersV2(context.Background(), chatId) if err != nil { log.WithError(err).Warn("failure getting chat members") return @@ -1432,19 +1156,19 @@ func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sen eg.SetLimit(min(32, len(members))) for _, m := range members { - if m.MemberId == sender || m.IsMuted || m.IsUnsubscribed { + if m.Owner == "" || m.IsMuted { continue } - owner, err := common.NewAccountFromPublicKeyString(m.GetOwner()) + owner, err := common.NewAccountFromPublicKeyString(m.Owner) if err != nil { - log.WithError(err).WithField("member", m.MemberId.String()).Warn("failure getting owner") + log.WithError(err).WithField("member", m.MemberId).Warn("failure getting owner") continue } m := m eg.Go(func() error { - log.WithField("member", m.MemberId.String()).Info("sending push notification") + log.WithField("member", m.MemberId).Info("sending push notification") err = push_util.SendChatMessagePushNotificationV2( context.Background(), s.data, @@ -1468,88 +1192,11 @@ func (s *Server) sendPushNotifications(chatId chat.ChatId, chatTitle string, sen _ = eg.Wait() } -func (s *Server) getAllIdentities(ctx context.Context, owner *common.Account) (map[chat.Platform]string, error) { - identities := map[chat.Platform]string{ - chat.PlatformCode: owner.PublicKey().ToBase58(), - } - - twitterUserame, ok, err := s.getOwnedTwitterUsername(ctx, owner) - if err != nil { - return nil, err - } - if ok { - identities[chat.PlatformTwitter] = twitterUserame - } - - return identities, nil -} - -func (s *Server) ownsChatMemberWithoutRecord(ctx context.Context, chatId chat.ChatId, memberId chat.MemberId, owner *common.Account) (bool, error) { - memberRecord, err := s.data.GetChatMemberByIdV2(ctx, chatId, memberId) - switch err { - case nil: - case chat.ErrMemberNotFound: - return false, nil - default: - return false, errors.Wrap(err, "error getting member record") - } - - return s.ownsChatMemberWithRecord(ctx, chatId, memberRecord, owner) -} - -func (s *Server) ownsChatMemberWithRecord(ctx context.Context, chatId chat.ChatId, memberRecord *chat.MemberRecord, owner *common.Account) (bool, error) { - switch memberRecord.Platform { - case chat.PlatformCode: - return memberRecord.PlatformId == owner.PublicKey().ToBase58(), nil - case chat.PlatformTwitter: - return s.ownsTwitterUsername(ctx, owner, memberRecord.PlatformId) - default: - return false, nil - } -} - -// todo: This logic should live elsewhere in somewhere more common -func (s *Server) ownsTwitterUsername(ctx context.Context, owner *common.Account, username string) (bool, error) { - ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) - if err != nil { - return false, errors.Wrap(err, "error deriving twitter tip address") - } - - twitterRecord, err := s.data.GetTwitterUserByUsername(ctx, username) - switch err { - case nil: - case twitter.ErrUserNotFound: - return false, nil - default: - return false, errors.Wrap(err, "error getting twitter user") - } - - return twitterRecord.TipAddress == ownerTipAccount.PublicKey().ToBase58(), nil -} - -// todo: This logic should live elsewhere in somewhere more common -func (s *Server) getOwnedTwitterUsername(ctx context.Context, owner *common.Account) (string, bool, error) { - ownerTipAccount, err := owner.ToTimelockVault(timelock_token.DataVersion1, common.KinMintAccount) - if err != nil { - return "", false, errors.Wrap(err, "error deriving twitter tip address") - } - - twitterRecord, err := s.data.GetTwitterUserByTipAddress(ctx, ownerTipAccount.PublicKey().ToBase58()) - switch err { - case nil: - return twitterRecord.Username, true, nil - case twitter.ErrUserNotFound: - return "", false, nil - default: - return "", false, errors.Wrap(err, "error getting twitter user") - } -} - -func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.ChatMessage { +func newProtoChatMessage(sender chat.MemberId, content ...*chatpb.Content) *chatpb.Message { messageId := chat.GenerateMessageId() ts, _ := messageId.GetTimestamp() - return &chatpb.ChatMessage{ + return &chatpb.Message{ MessageId: messageId.ToProto(), SenderId: sender.ToProto(), Content: content, diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go index aacc4f95..951f5a21 100644 --- a/pkg/code/server/grpc/chat/v2/server_test.go +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -1 +1,376 @@ package chat_v2 + +import ( + "bytes" + "context" + "fmt" + "slices" + "testing" + "time" + + "github.com/mr-tron/base58" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + + auth_util "github.com/code-payments/code-server/pkg/code/auth" + "github.com/code-payments/code-server/pkg/code/common" + "github.com/code-payments/code-server/pkg/code/data" + "github.com/code-payments/code-server/pkg/code/data/account" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/code/data/intent" + "github.com/code-payments/code-server/pkg/code/data/twitter" + "github.com/code-payments/code-server/pkg/currency" + pushmemory "github.com/code-payments/code-server/pkg/push/memory" + "github.com/code-payments/code-server/pkg/testutil" + "github.com/stretchr/testify/require" +) + +func TestServerHappy(t *testing.T) { + env, cleanup := setup(t) + defer cleanup() + + userA := testutil.NewRandomAccount(t) + userB := testutil.NewRandomAccount(t) + + ctx := context.Background() + for i, u := range []*common.Account{userA, userB} { + tipAddr, err := u.ToMessagingAccount(common.KinMintAccount) + require.NoError(t, err) + + userSuffix := string(rune('a' + i)) + + err = env.data.SaveTwitterUser(ctx, &twitter.Record{ + Username: fmt.Sprintf("username-%s", userSuffix), + Name: fmt.Sprintf("name-%s", userSuffix), + ProfilePicUrl: fmt.Sprintf("pp-%s", userSuffix), + TipAddress: tipAddr.PublicKey().ToBase58(), + LastUpdatedAt: time.Now(), + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + err = env.data.CreateAccountInfo(ctx, &account.Record{ + OwnerAccount: u.String(), + AuthorityAccount: u.String(), + TokenAccount: base58.Encode(u.MustToChatMemberId()), + MintAccount: common.KinMintAccount.String(), + AccountType: commonpb.AccountType_PRIMARY, + CreatedAt: time.Now(), + }) + require.NoError(t, err) + } + + chatId := chat.GetTwoWayChatId(userA.MustToChatMemberId(), userB.MustToChatMemberId()) + intentId := bytes.Repeat([]byte{1}, 32) + err := env.data.SaveIntent(ctx, &intent.Record{ + IntentId: base58.Encode(intentId), + IntentType: intent.SendPrivatePayment, + InitiatorOwnerAccount: userA.String(), + SendPrivatePaymentMetadata: &intent.SendPrivatePaymentMetadata{ + DestinationTokenAccount: userB.String(), + Quantity: 10, + ExchangeCurrency: currency.USD, + ExchangeRate: 10, + UsdMarketValue: 10.0, + NativeAmount: 1, + IsChat: true, + ChatId: base58.Encode(chatId[:]), + }, + State: intent.StateConfirmed, + CreatedAt: time.Now(), + }) + require.NoError(t, err) + + t.Run("Initial State", func(t *testing.T) { + req := &chatpb.GetChatsRequest{Owner: userA.ToProto()} + req.Signature = signProtoMessage(t, req, userA, false) + + chats, err := env.client.GetChats(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.Empty(t, chats.Chats) + }) + + t.Run("StartChat", func(t *testing.T) { + req := &chatpb.StartChatRequest{ + Owner: userA.ToProto(), + Parameters: &chatpb.StartChatRequest_TwoWayChat{ + TwoWayChat: &chatpb.StartTwoWayChatParameters{ + OtherUser: &commonpb.SolanaAccountId{Value: userB.MustToChatMemberId()}, + IntentId: &commonpb.IntentId{Value: intentId}, + }, + }, + } + req.Signature = signProtoMessage(t, req, userA, false) + + resp, err := env.client.StartChat(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.StartChatResponse_OK, resp.Result) + require.NotEmpty(t, resp.GetChat().GetChatId()) + + expectedMeta := &chatpb.Metadata{ + ChatId: resp.Chat.ChatId, + Type: chatpb.ChatType_TWO_WAY, + Cursor: &chatpb.Cursor{Value: resp.Chat.ChatId.Value}, + Title: "", + Members: []*chatpb.Member{ + { + MemberId: userA.MustToChatMemberId().ToProto(), + Identity: &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: "username-a", + DisplayName: "name-a", + ProfilePicUrl: "pp-a", + }, + }, + { + MemberId: userB.MustToChatMemberId().ToProto(), + Identity: &chatpb.MemberIdentity{ + Platform: chatpb.Platform_TWITTER, + Username: "username-b", + DisplayName: "name-b", + ProfilePicUrl: "pp-b", + }, + }, + }, + } + + slices.SortFunc(expectedMeta.Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + slices.SortFunc(resp.Chat.Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + + require.NoError(t, testutil.ProtoEqual(expectedMeta, resp.Chat)) + + for _, u := range []*common.Account{userA, userB} { + getChats := &chatpb.GetChatsRequest{Owner: u.ToProto()} + getChats.Signature = signProtoMessage(t, getChats, u, false) + + chats, err := env.client.GetChats(ctx, getChats) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.Len(t, chats.Chats, 1) + + slices.SortFunc(chats.Chats[0].Members, func(a, b *chatpb.Member) int { + return bytes.Compare(a.MemberId.Value, b.MemberId.Value) + }) + + require.NoError(t, testutil.ProtoEqual(resp.Chat, chats.Chats[0])) + } + }) + + var messages []*chatpb.Message + t.Run("Send Messages", func(t *testing.T) { + for _, u := range []*common.Account{userA, userB} { + for i := 0; i < 5; i++ { + req := &chatpb.SendMessageRequest{ + ChatId: chatId.ToProto(), + Owner: u.ToProto(), + Content: []*chatpb.Content{ + { + Type: &chatpb.Content_Text{ + Text: &chatpb.TextContent{ + Text: fmt.Sprintf("message-%d", i), + }, + }, + }, + }, + } + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.SendMessage(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.SendMessageResponse_OK, resp.Result) + messages = append(messages, resp.GetMessage()) + + // TODO: Hack on message generation...again. + time.Sleep(time.Millisecond) + } + } + + for _, u := range []*common.Account{userA, userB} { + req := &chatpb.GetChatsRequest{Owner: u.ToProto()} + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.GetChats(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, resp.Result) + + // 5 unread _each_ + require.EqualValues(t, 5, resp.Chats[0].NumUnread) + } + }) + + t.Run("Get Messages", func(t *testing.T) { + for _, u := range []*common.Account{userA, userB} { + req := &chatpb.GetMessagesRequest{ + ChatId: chatId.ToProto(), + Owner: u.ToProto(), + } + req.Signature = signProtoMessage(t, req, u, false) + + resp, err := env.client.GetMessages(ctx, req) + require.NoError(t, err) + require.NoError(t, testutil.ProtoSliceEqual(messages, resp.GetMessages())) + + req.Cursor = resp.Messages[1].GetCursor() + req.Signature = nil + req.Signature = signProtoMessage(t, req, u, false) + + resp, err = env.client.GetMessages(ctx, req) + require.NoError(t, err) + require.NoError(t, testutil.ProtoSliceEqual(messages[2:], resp.GetMessages())) + } + }) + + t.Run("Advance Pointer", func(t *testing.T) { + for _, tc := range []struct { + offset int + user *common.Account + }{ + {offset: 5 + 2, user: userA}, + {offset: 0 + 2, user: userB}, + } { + req := &chatpb.AdvancePointerRequest{ + ChatId: chatId.ToProto(), + Pointer: &chatpb.Pointer{ + Type: chatpb.PointerType_READ, + Value: messages[tc.offset].MessageId, + MemberId: tc.user.MustToChatMemberId().ToProto(), + }, + Owner: tc.user.ToProto(), + } + req.Signature = signProtoMessage(t, req, tc.user, false) + + resp, err := env.client.AdvancePointer(ctx, req) + require.NoError(t, err) + require.Equal(t, chatpb.AdvancePointerResponse_OK, resp.Result) + + getChats := &chatpb.GetChatsRequest{Owner: tc.user.ToProto()} + getChats.Signature = signProtoMessage(t, getChats, tc.user, false) + + chats, err := env.client.GetChats(ctx, getChats) + require.NoError(t, err) + require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) + require.EqualValues(t, 2, chats.Chats[0].NumUnread) + } + }) + + t.Run("Stream", func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + client, err := env.client.StreamChatEvents(ctx) + require.NoError(t, err) + + req := &chatpb.OpenChatEventStream{ + ChatId: chatId.ToProto(), + Owner: userA.ToProto(), + Signature: nil, + } + req.Signature = signProtoMessage(t, req, userA, false) + + err = client.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_OpenStream{ + OpenStream: req, + }, + }) + require.NoError(t, err) + + // expect some amount of flushes + var streamedMessages []*chatpb.Message + for { + resp, err := client.Recv() + require.NoError(t, err) + + switch typed := resp.Type.(type) { + case *chatpb.StreamChatEventsResponse_Error: + require.FailNow(t, typed.Error.String()) + case *chatpb.StreamChatEventsResponse_Ping: + _ = client.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_Pong{ + Pong: &commonpb.ClientPong{Timestamp: timestamppb.Now()}, + }, + }) + + case *chatpb.StreamChatEventsResponse_Events: + for _, e := range typed.Events.Events { + if m := e.GetMessage(); m != nil { + streamedMessages = append(streamedMessages, m) + if len(streamedMessages) == len(messages) { + break + } + } + } + + default: + } + + if len(streamedMessages) == len(messages) { + break + } + } + + require.True(t, slices.IsSortedFunc(streamedMessages, func(a, b *chatpb.Message) int { + return -1 * bytes.Compare(a.MessageId.Value, b.MessageId.Value) + })) + slices.Reverse(streamedMessages) + require.NoError(t, testutil.ProtoSliceEqual(messages, streamedMessages)) + }) +} + +type testEnv struct { + ctx context.Context + client chatpb.ChatClient + server *Server + data data.Provider +} + +func setup(t *testing.T) (env *testEnv, cleanup func()) { + conn, serv, err := testutil.NewServer() + require.NoError(t, err) + + env = &testEnv{ + ctx: context.Background(), + client: chatpb.NewChatClient(conn), + data: data.NewTestDataProvider(), + } + + env.server = NewChatServer( + env.data, + auth_util.NewRPCSignatureVerifier(env.data), + pushmemory.NewPushProvider(), + ) + + serv.RegisterService(func(server *grpc.Server) { + chatpb.RegisterChatServer(server, env.server) + }) + + testutil.SetupRandomSubsidizer(t, env.data) + + cleanup, err = serv.Serve() + require.NoError(t, err) + return env, cleanup +} + +func signProtoMessage(t *testing.T, msg proto.Message, signer *common.Account, simulateInvalidSignature bool) *commonpb.Signature { + msgBytes, err := proto.Marshal(msg) + require.NoError(t, err) + + if simulateInvalidSignature { + signer = testutil.NewRandomAccount(t) + } + + signature, err := signer.Sign(msgBytes) + require.NoError(t, err) + + return &commonpb.Signature{ + Value: signature, + } +} diff --git a/pkg/code/server/grpc/transaction/v2/intent_handler.go b/pkg/code/server/grpc/transaction/v2/intent_handler.go index b54cb7b8..85b8d4eb 100644 --- a/pkg/code/server/grpc/transaction/v2/intent_handler.go +++ b/pkg/code/server/grpc/transaction/v2/intent_handler.go @@ -475,6 +475,7 @@ func (h *SendPrivatePaymentIntentHandler) PopulateMetadata(ctx context.Context, IsRemoteSend: typedProtoMetadata.IsRemoteSend, IsMicroPayment: isMicroPayment, IsTip: typedProtoMetadata.IsTip, + IsChat: typedProtoMetadata.IsChat, } if typedProtoMetadata.IsTip { @@ -488,6 +489,14 @@ func (h *SendPrivatePaymentIntentHandler) PopulateMetadata(ctx context.Context, } } + if typedProtoMetadata.IsChat { + if typedProtoMetadata.ChatId == nil { + return newIntentValidationError("chat id is missing") + } + + intentRecord.SendPrivatePaymentMetadata.ChatId = base58.Encode(typedProtoMetadata.ChatId.GetValue()) + } + if destinationAccountInfo != nil { intentRecord.SendPrivatePaymentMetadata.DestinationOwnerAccount = destinationAccountInfo.OwnerAccount } diff --git a/pkg/code/server/grpc/transaction/v2/testutil.go b/pkg/code/server/grpc/transaction/v2/testutil.go index a160c2db..106feb6a 100644 --- a/pkg/code/server/grpc/transaction/v2/testutil.go +++ b/pkg/code/server/grpc/transaction/v2/testutil.go @@ -25,7 +25,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v1" - chatv2pb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" messagingpb "github.com/code-payments/code-protobuf-api/generated/go/messaging/v1" transactionpb "github.com/code-payments/code-protobuf-api/generated/go/transaction/v2" @@ -38,7 +37,6 @@ import ( "github.com/code-payments/code-server/pkg/code/data/account" "github.com/code-payments/code-server/pkg/code/data/action" chat_v1 "github.com/code-payments/code-server/pkg/code/data/chat/v1" - chat_v2 "github.com/code-payments/code-server/pkg/code/data/chat/v2" "github.com/code-payments/code-server/pkg/code/data/commitment" "github.com/code-payments/code-server/pkg/code/data/currency" "github.com/code-payments/code-server/pkg/code/data/deposit" @@ -6182,9 +6180,3 @@ func getProtoChatMessage(t *testing.T, record *chat_v1.Message) *chatpb.ChatMess require.NoError(t, proto.Unmarshal(record.Data, &protoMessage)) return &protoMessage } - -func getProtoChatMessageV2(t *testing.T, record *chat_v2.MessageRecord) *chatv2pb.ChatMessage { - protoMessage := &chatv2pb.ChatMessage{} - require.NoError(t, proto.Unmarshal(record.Data, protoMessage)) - return protoMessage -} diff --git a/pkg/code/server/grpc/user/server.go b/pkg/code/server/grpc/user/server.go index c4f536cf..35469244 100644 --- a/pkg/code/server/grpc/user/server.go +++ b/pkg/code/server/grpc/user/server.go @@ -705,7 +705,19 @@ func (s *identityServer) GetTwitterUser(ctx context.Context, req *userpb.GetTwit var friendChatId *commonpb.ChatId if req.Requestor != nil { // TODO: Validate the requestor - friendChatId = chat.GetChatId(base58.Encode(req.Requestor.Value), tipAddress.PublicKey().ToBase58(), true).ToProto() + ownerAccount, err := common.NewAccountFromProto(req.Requestor) + if err != nil { + log.WithError(err).Warn("failed to get owner account") + return nil, status.Error(codes.Internal, "") + } + + ownerMessagingAccount, err := ownerAccount.ToMessagingAccount(common.KinMintAccount) + if err != nil { + log.WithError(err).Warn("failed to get owner messaging account") + return nil, status.Error(codes.Internal, "") + } + + friendChatId = chat.GetTwoWayChatId(ownerMessagingAccount.PublicKey().ToBytes(), tipAddress.PublicKey().ToBytes()).ToProto() } return &userpb.GetTwitterUserResponse{ diff --git a/pkg/database/query/cursor.go b/pkg/database/query/cursor.go index 533ef49f..08e29085 100644 --- a/pkg/database/query/cursor.go +++ b/pkg/database/query/cursor.go @@ -9,7 +9,7 @@ import ( type Cursor []byte var ( - EmptyCursor Cursor = Cursor([]byte{}) + EmptyCursor = Cursor([]byte{}) ) func ToCursor(val uint64) Cursor { diff --git a/pkg/pointer/pointer.go b/pkg/pointer/pointer.go index a353d347..e32fbf22 100644 --- a/pkg/pointer/pointer.go +++ b/pkg/pointer/pointer.go @@ -7,6 +7,15 @@ func String(value string) *string { return &value } +// StringOrEmpty returns the value of the string, if set. Otherwise, "". +func StringOrEmpty(value *string) string { + if value != nil { + return *value + } + + return "" +} + // StringOrDefault returns the pointer if not nil, otherwise the default value func StringOrDefault(value *string, defaultValue string) *string { if value != nil { diff --git a/pkg/testutil/proto.go b/pkg/testutil/proto.go new file mode 100644 index 00000000..a9ba5dce --- /dev/null +++ b/pkg/testutil/proto.go @@ -0,0 +1,33 @@ +package testutil + +import ( + "fmt" + + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +func ProtoEqual(a, b proto.Message) error { + if proto.Equal(a, b) { + return nil + } + + aJSON, _ := protojson.Marshal(a) + bJSON, _ := protojson.Marshal(b) + + return fmt.Errorf("expected: %v\nactual: %v", string(aJSON), string(bJSON)) +} + +func ProtoSliceEqual[T proto.Message](a, b []T) error { + if len(a) != len(b) { + return fmt.Errorf("len(%d) != len(%d)", len(a), len(b)) + } + + for i := range a { + if err := ProtoEqual(a[i], b[i]); err != nil { + return fmt.Errorf("element mismatch at %d\n%w", i, err) + } + } + + return nil +} From 0bd5d153552369942ee2c45f6d13e52595069b28 Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Thu, 3 Oct 2024 11:09:24 -0400 Subject: [PATCH 70/71] chat: add is_self back to chat metadata --- go.mod | 2 +- go.sum | 4 ++-- pkg/code/server/grpc/chat/v2/server.go | 6 ++++++ pkg/code/server/grpc/chat/v2/server_test.go | 5 +++++ 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 2463acd3..527d2967 100644 --- a/go.mod +++ b/go.mod @@ -132,4 +132,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8 +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa diff --git a/go.sum b/go.sum index 16246996..75a224a2 100644 --- a/go.sum +++ b/go.sum @@ -423,8 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8 h1:cP0i0oAMtWtyBP0wMOuVOzg2i3dYQZOuq2CtXrgr8iM= -github.com/mfycheng/code-protobuf-api v0.0.0-20240930161350-0d6798fdd5b8/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa h1:mL+mTGMgq3bFdJ0Z4Mu4m5bL2Lg0LRR6WehBBvYllwE= +github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 235b5265..86d986f1 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -859,6 +859,10 @@ func (s *Server) NotifyIsTyping(ctx context.Context, req *chatpb.NotifyIsTypingR return &chatpb.NotifyIsTypingResponse{}, nil } +func (s *Server) NotifyMessage(_ context.Context, _ chat.ChatId, _ *chatpb.Message) { + // TODO: Cleanup this up +} + // todo: needs to have a 'fill' version func (s *Server) getMetadata(ctx context.Context, asMember chat.MemberId, id chat.ChatId) (*chatpb.Metadata, error) { mdRecord, err := s.data.GetChatMetadata(ctx, id) @@ -936,6 +940,8 @@ func (s *Server) populateMetadata(ctx context.Context, mdRecord *chat.MetadataRe continue } + member.IsSelf = true + // TODO: Do we actually want to compute this feature? It's maybe non-trivial. // Maybe should have a safety valve at minimum. md.NumUnread, err = s.data.GetChatUnreadCountV2(ctx, mdRecord.ChatId, memberId, m.ReadPointer) diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go index 951f5a21..3597a4ca 100644 --- a/pkg/code/server/grpc/chat/v2/server_test.go +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -126,6 +126,7 @@ func TestServerHappy(t *testing.T) { DisplayName: "name-a", ProfilePicUrl: "pp-a", }, + IsSelf: true, }, { MemberId: userB.MustToChatMemberId().ToProto(), @@ -152,6 +153,10 @@ func TestServerHappy(t *testing.T) { getChats := &chatpb.GetChatsRequest{Owner: u.ToProto()} getChats.Signature = signProtoMessage(t, getChats, u, false) + for _, member := range resp.Chat.Members { + member.IsSelf = bytes.Equal(u.MustToChatMemberId(), member.MemberId.Value) + } + chats, err := env.client.GetChats(ctx, getChats) require.NoError(t, err) require.Equal(t, chatpb.GetChatsResponse_OK, chats.Result) From 1241ef93b57016321db084a951bdaeeeba25d54e Mon Sep 17 00:00:00 2001 From: Mike Cheng Date: Tue, 15 Oct 2024 01:33:51 -0400 Subject: [PATCH 71/71] chat/v2: add dual stream implementation (local server) --- go.mod | 2 +- go.sum | 4 +- pkg/code/server/grpc/chat/v2/server.go | 273 +--------- pkg/code/server/grpc/chat/v2/server_test.go | 105 +++- pkg/code/server/grpc/chat/v2/stream.go | 158 +++--- pkg/code/server/grpc/chat/v2/streams.go | 522 ++++++++++++++++++++ 6 files changed, 700 insertions(+), 364 deletions(-) create mode 100644 pkg/code/server/grpc/chat/v2/streams.go diff --git a/go.mod b/go.mod index 527d2967..5eff3863 100644 --- a/go.mod +++ b/go.mod @@ -132,4 +132,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa +replace github.com/code-payments/code-protobuf-api => github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d diff --git a/go.sum b/go.sum index 75a224a2..0db72051 100644 --- a/go.sum +++ b/go.sum @@ -423,8 +423,8 @@ github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Ky github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= -github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa h1:mL+mTGMgq3bFdJ0Z4Mu4m5bL2Lg0LRR6WehBBvYllwE= -github.com/mfycheng/code-protobuf-api v0.0.0-20241003150533-14516868a8aa/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= +github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d h1:pOwndvvkUvWXzoiJIIo5wiPT/IP67J5AJqF4sLPdKcY= +github.com/mfycheng/code-protobuf-api v0.0.0-20241010162320-5dac31db232d/go.mod h1:pHQm75vydD6Cm2qHAzlimW6drysm489Z4tVxC2zHSsU= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= diff --git a/pkg/code/server/grpc/chat/v2/server.go b/pkg/code/server/grpc/chat/v2/server.go index 86d986f1..a339b24d 100644 --- a/pkg/code/server/grpc/chat/v2/server.go +++ b/pkg/code/server/grpc/chat/v2/server.go @@ -18,11 +18,9 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" - commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" auth_util "github.com/code-payments/code-server/pkg/code/auth" "github.com/code-payments/code-server/pkg/code/common" code_data "github.com/code-payments/code-server/pkg/code/data" @@ -53,7 +51,7 @@ type Server struct { push push.Provider streamsMu sync.RWMutex - streams map[string]*chatEventStream + streams map[string]eventStream chatLocks *sync_util.StripedLock chatEventChans *sync_util.StripedChannel @@ -73,7 +71,7 @@ func NewChatServer( auth: auth, push: push, - streams: make(map[string]*chatEventStream), + streams: make(map[string]eventStream), chatLocks: sync_util.NewStripedLock(64), // todo: configurable parameters chatEventChans: sync_util.NewStripedChannel(64, 100_000), // todo: configurable parameters @@ -240,151 +238,6 @@ func (s *Server) GetMessages(ctx context.Context, req *chatpb.GetMessagesRequest }, nil } -func (s *Server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) error { - ctx := streamer.Context() - - log := s.log.WithField("method", "StreamChatEvents") - log = client.InjectLoggingMetadata(ctx, log) - - req, err := boundedStreamChatEventsRecv(ctx, streamer, 250*time.Millisecond) - if err != nil { - return err - } - - if req.GetOpenStream() == nil { - return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") - } - - owner, err := common.NewAccountFromProto(req.GetOpenStream().Owner) - if err != nil { - log.WithError(err).Warn("invalid owner account") - return status.Error(codes.Internal, "") - } - log = log.WithField("owner", owner.PublicKey().ToBase58()) - - chatId, err := chat.GetChatIdFromProto(req.GetOpenStream().ChatId) - if err != nil { - log.WithError(err).Warn("invalid chat id") - return status.Error(codes.Internal, "") - } - log = log.WithField("chat_id", chatId.String()) - - memberId, err := owner.ToChatMemberId() - if err != nil { - log.WithError(err).Warn("failed to derive messaging account") - return status.Error(codes.Internal, "") - } - log = log.WithField("member_id", memberId.String()) - - signature := req.GetOpenStream().Signature - req.GetOpenStream().Signature = nil - if err := s.auth.Authenticate(streamer.Context(), owner, req.GetOpenStream(), signature); err != nil { - return err - } - - isMember, err := s.data.IsChatMember(ctx, chatId, memberId) - if err != nil { - log.WithError(err).Warn("failed to derive messaging account") - return status.Error(codes.Internal, "") - } - if !isMember { - return streamer.Send(&chatpb.StreamChatEventsResponse{ - Type: &chatpb.StreamChatEventsResponse_Error{ - Error: &chatpb.ChatStreamEventError{Code: chatpb.ChatStreamEventError_DENIED}, - }, - }) - } - - streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) - - s.streamsMu.Lock() - - stream, exists := s.streams[streamKey] - if exists { - s.streamsMu.Unlock() - // There's an existing stream on this Server that must be terminated first. - // Warn to see how often this happens in practice - log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) - return status.Error(codes.Aborted, "stream already exists") - } - - stream = newChatEventStream(streamBufferSize) - - // The race detector complains when reading the stream pointer ref outside of the lock. - streamRef := fmt.Sprintf("%p", stream) - log.Tracef("setting up new stream (stream=%s)", streamRef) - s.streams[streamKey] = stream - - s.streamsMu.Unlock() - - defer func() { - s.streamsMu.Lock() - - log.Tracef("closing streamer (stream=%s)", streamRef) - - // We check to see if the current active stream is the one that we created. - // If it is, we can just remove it since it's closed. Otherwise, we leave it - // be, as another StreamChatEvents() call is handling it. - liveStream, exists := s.streams[streamKey] - if exists && liveStream == stream { - delete(s.streams, streamKey) - } - - s.streamsMu.Unlock() - }() - - sendPingCh := time.After(0) - streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer) - - // todo: We should also "flush" pointers for each chat member - go s.flushMessages(ctx, chatId, owner, stream) - - for { - select { - case event, ok := <-stream.streamCh: - if !ok { - log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) - return status.Error(codes.Aborted, "stream closed") - } - - err := streamer.Send(&chatpb.StreamChatEventsResponse{ - Type: &chatpb.StreamChatEventsResponse_Events{ - Events: &chatpb.ChatStreamEventBatch{ - Events: []*chatpb.ChatStreamEvent{event}, - }, - }, - }) - if err != nil { - log.WithError(err).Info("failed to forward chat message") - return err - } - case <-sendPingCh: - log.Tracef("sending ping to client (stream=%s)", streamRef) - - sendPingCh = time.After(streamPingDelay) - - err := streamer.Send(&chatpb.StreamChatEventsResponse{ - Type: &chatpb.StreamChatEventsResponse_Ping{ - Ping: &commonpb.ServerPing{ - Timestamp: timestamppb.Now(), - PingDelay: durationpb.New(streamPingDelay), - }, - }, - }) - if err != nil { - log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) - return status.Error(codes.Aborted, "terminating unhealthy stream") - } - case <-streamHealthCh: - log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) - return status.Error(codes.Aborted, "terminating unhealthy stream") - case <-ctx.Done(): - log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) - return status.Error(codes.Canceled, "") - } - } -} - func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (*chatpb.StartChatResponse, error) { log := s.log.WithField("method", "StartChat") log = client.InjectLoggingMetadata(ctx, log) @@ -454,9 +307,7 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* switch intentRecord.State { case intent.StatePending: - return &chatpb.StartChatResponse{ - Result: chatpb.StartChatResponse_PENDING, - }, nil + log.Info("Payment intent is pending") case intent.StateConfirmed: default: log.WithField("state", intentRecord.State).Info("PayToChat intent did not succeed") @@ -571,6 +422,14 @@ func (s *Server) StartChat(ctx context.Context, req *chatpb.StartChatRequest) (* return nil, status.Error(codes.Internal, "") } + event := &chatEventNotification{ + chatId: chatId, + chatUpdate: md, + } + if err = s.asyncNotifyAll(chatId, event); err != nil { + log.WithError(err).Warn("failed to notify event stream") + } + return &chatpb.StartChatResponse{ Result: chatpb.StartChatResponse_OK, Chat: md, @@ -735,10 +594,9 @@ func (s *Server) AdvancePointer(ctx context.Context, req *chatpb.AdvancePointerR } if isAdvanced { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Pointer{ - Pointer: req.Pointer, - }, + event := &chatEventNotification{ + chatId: chatId, + pointerUpdate: req.Pointer, } if err := s.asyncNotifyAll(chatId, event); err != nil { log.WithError(err).Warn("failure notifying chat event") @@ -843,12 +701,15 @@ func (s *Server) NotifyIsTyping(ctx context.Context, req *chatpb.NotifyIsTypingR return &chatpb.NotifyIsTypingResponse{Result: chatpb.NotifyIsTypingResponse_DENIED}, nil } - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_IsTyping{ - IsTyping: &chatpb.IsTyping{ - MemberId: memberId.ToProto(), - IsTyping: req.IsTyping, - }, + // Internalize the event + // notifyAll sends to both(depending on type) + // notifyAll then determines the actual assembly + + event := &chatEventNotification{ + chatId: chatId, + isTyping: &chatpb.IsTyping{ + MemberId: memberId.ToProto(), + IsTyping: req.IsTyping, }, } @@ -952,87 +813,6 @@ func (s *Server) populateMetadata(ctx context.Context, mdRecord *chat.MetadataRe return md, nil } - -func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream *chatEventStream) { - log := s.log.WithFields(logrus.Fields{ - "method": "flushMessages", - "chat_id": chatId.String(), - "owner_account": owner.PublicKey().ToBase58(), - }) - - protoChatMessages, err := s.getProtoChatMessages( - ctx, - chatId, - owner, - query.WithCursor(query.EmptyCursor), - query.WithDirection(query.Descending), - query.WithLimit(flushMessageCount), - ) - if err != nil { - log.WithError(err).Warn("failure getting chat messages") - return - } - - for _, protoChatMessage := range protoChatMessages { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: protoChatMessage, - }, - } - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - return - } - } -} - -func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) { - log := s.log.WithFields(logrus.Fields{ - "method": "flushPointers", - "chat_id": chatId.String(), - }) - - memberRecords, err := s.data.GetChatMembersV2(ctx, chatId) - if err != nil { - log.WithError(err).Warn("failure getting chat members") - return - } - - for _, memberRecord := range memberRecords { - for _, optionalPointer := range []struct { - kind chat.PointerType - value *chat.MessageId - }{ - {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, - {chat.PointerTypeRead, memberRecord.ReadPointer}, - } { - if optionalPointer.value == nil { - continue - } - - memberId, err := chat.GetMemberIdFromString(memberRecord.MemberId) - if err != nil { - log.WithError(err).Warnf("failure getting memberId from %s", memberRecord.MemberId) - return - } - - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Pointer{ - Pointer: &chatpb.Pointer{ - Type: optionalPointer.kind.ToProto(), - Value: optionalPointer.value.ToProto(), - MemberId: memberId.ToProto(), - }, - }, - } - if err := stream.notify(event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - return - } - } - } -} - func (s *Server) getProtoChatMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, queryOptions ...query.Option) ([]*chatpb.Message, error) { messageRecords, err := s.data.GetAllChatMessagesV2(ctx, chatId, queryOptions...) if err != nil { @@ -1133,10 +913,9 @@ func (s *Server) persistChatMessage(ctx context.Context, chatId chat.ChatId, pro } func (s *Server) onPersistChatMessage(log *logrus.Entry, chatId chat.ChatId, chatMessage *chatpb.Message) { - event := &chatpb.ChatStreamEvent{ - Type: &chatpb.ChatStreamEvent_Message{ - Message: chatMessage, - }, + event := &chatEventNotification{ + chatId: chatId, + messageUpdate: chatMessage, } if err := s.asyncNotifyAll(chatId, event); err != nil { diff --git a/pkg/code/server/grpc/chat/v2/server_test.go b/pkg/code/server/grpc/chat/v2/server_test.go index 3597a4ca..edfc8311 100644 --- a/pkg/code/server/grpc/chat/v2/server_test.go +++ b/pkg/code/server/grpc/chat/v2/server_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "fmt" + "github.com/sirupsen/logrus" "slices" "testing" "time" @@ -95,6 +96,51 @@ func TestServerHappy(t *testing.T) { require.Empty(t, chats.Chats) }) + eventCtx, eventCancel := context.WithTimeout(ctx, time.Minute) + defer eventCancel() + + eventClient, err := env.client.StreamChatEvents(eventCtx) + require.NoError(t, err) + + req := &chatpb.StreamChatEventsRequest_Params{ + Owner: userA.ToProto(), + } + req.Signature = signProtoMessage(t, req, userA, false) + + err = eventClient.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_Params_{ + Params: req, + }, + }) + eventCh := make(chan *chatpb.StreamChatEventsResponse_EventBatch, 1024) + + go func() { + defer close(eventCh) + + for { + msg, err := eventClient.Recv() + if err != nil { + env.log.WithError(err).Error("Failed to receive event stream") + return + } + + switch typed := msg.Type.(type) { + case *chatpb.StreamChatEventsResponse_Ping: + _ = eventClient.Send(&chatpb.StreamChatEventsRequest{ + Type: &chatpb.StreamChatEventsRequest_Pong{ + Pong: &commonpb.ClientPong{ + Timestamp: timestamppb.Now(), + }, + }, + }) + case *chatpb.StreamChatEventsResponse_Error: + env.log.WithError(err).WithField("code", typed.Error.Code).Warn("failed to receive update event") + case *chatpb.StreamChatEventsResponse_Events: + eventCh <- typed.Events + } + } + }() + t.Run("StartChat", func(t *testing.T) { req := &chatpb.StartChatRequest{ Owner: userA.ToProto(), @@ -267,23 +313,48 @@ func TestServerHappy(t *testing.T) { } }) - t.Run("Stream", func(t *testing.T) { + eventCancel() + t.Run("Event Stream", func(t *testing.T) { + var events []*chatpb.StreamChatEventsResponse_ChatUpdate + for batch := range eventCh { + for _, e := range batch.Updates { + events = append(events, e) + } + } + + require.Equal(t, 13, len(events)) + + // Chat creation + require.NotNil(t, events[0].Metadata) + + // 10 messages + for i := 1; i < 10+1; i++ { + require.NotNil(t, events[i].LastMessage) + } + + // Pointer updates + for i := 1 + 10; i < 13; i++ { + require.NotNil(t, events[i].Pointer) + } + }) + + t.Run("Message Stream", func(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - client, err := env.client.StreamChatEvents(ctx) + client, err := env.client.StreamMessages(ctx) require.NoError(t, err) - req := &chatpb.OpenChatEventStream{ + req := &chatpb.StreamMessagesRequest_Params{ ChatId: chatId.ToProto(), Owner: userA.ToProto(), Signature: nil, } req.Signature = signProtoMessage(t, req, userA, false) - err = client.Send(&chatpb.StreamChatEventsRequest{ - Type: &chatpb.StreamChatEventsRequest_OpenStream{ - OpenStream: req, + err = client.Send(&chatpb.StreamMessagesRequest{ + Type: &chatpb.StreamMessagesRequest_Params_{ + Params: req, }, }) require.NoError(t, err) @@ -295,22 +366,20 @@ func TestServerHappy(t *testing.T) { require.NoError(t, err) switch typed := resp.Type.(type) { - case *chatpb.StreamChatEventsResponse_Error: + case *chatpb.StreamMessagesResponse_Error: require.FailNow(t, typed.Error.String()) - case *chatpb.StreamChatEventsResponse_Ping: - _ = client.Send(&chatpb.StreamChatEventsRequest{ - Type: &chatpb.StreamChatEventsRequest_Pong{ + case *chatpb.StreamMessagesResponse_Ping: + _ = client.Send(&chatpb.StreamMessagesRequest{ + Type: &chatpb.StreamMessagesRequest_Pong{ Pong: &commonpb.ClientPong{Timestamp: timestamppb.Now()}, }, }) - case *chatpb.StreamChatEventsResponse_Events: - for _, e := range typed.Events.Events { - if m := e.GetMessage(); m != nil { - streamedMessages = append(streamedMessages, m) - if len(streamedMessages) == len(messages) { - break - } + case *chatpb.StreamMessagesResponse_Messages: + for _, m := range typed.Messages.Messages { + streamedMessages = append(streamedMessages, m) + if len(streamedMessages) == len(messages) { + break } } @@ -331,6 +400,7 @@ func TestServerHappy(t *testing.T) { } type testEnv struct { + log *logrus.Logger ctx context.Context client chatpb.ChatClient server *Server @@ -342,6 +412,7 @@ func setup(t *testing.T) (env *testEnv, cleanup func()) { require.NoError(t, err) env = &testEnv{ + log: logrus.StandardLogger(), ctx: context.Background(), client: chatpb.NewChatClient(conn), data: data.NewTestDataProvider(), diff --git a/pkg/code/server/grpc/chat/v2/stream.go b/pkg/code/server/grpc/chat/v2/stream.go index a5fa2017..ee2b859a 100644 --- a/pkg/code/server/grpc/chat/v2/stream.go +++ b/pkg/code/server/grpc/chat/v2/stream.go @@ -2,18 +2,15 @@ package chat_v2 import ( "context" - "strings" + "errors" "sync" "time" - "github.com/pkg/errors" "github.com/sirupsen/logrus" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" - - chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" - chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" ) const ( @@ -24,148 +21,115 @@ const ( streamNotifyTimeout = time.Second ) -type chatEventStream struct { +type eventStream interface { + notify(notification *chatEventNotification, timeout time.Duration) error +} + +type protoEventStream[T proto.Message] struct { sync.Mutex - closed bool - streamCh chan *chatpb.ChatStreamEvent + closed bool + ch chan T + transform func(*chatEventNotification) (T, bool) } -func newChatEventStream(bufferSize int) *chatEventStream { - return &chatEventStream{ - streamCh: make(chan *chatpb.ChatStreamEvent, bufferSize), +func newEventStream[T proto.Message]( + bufferSize int, + selector func(notification *chatEventNotification) (T, bool), +) *protoEventStream[T] { + return &protoEventStream[T]{ + ch: make(chan T, bufferSize), + transform: selector, } } -func (s *chatEventStream) notify(event *chatpb.ChatStreamEvent, timeout time.Duration) error { - m := proto.Clone(event).(*chatpb.ChatStreamEvent) - - s.Lock() +func (e *protoEventStream[T]) notify(event *chatEventNotification, timeout time.Duration) error { + msg, ok := e.transform(event) + if !ok { + return nil + } - if s.closed { - s.Unlock() + e.Lock() + if e.closed { + e.Unlock() return errors.New("cannot notify closed stream") } select { - case s.streamCh <- m: + case e.ch <- msg: case <-time.After(timeout): - s.Unlock() - s.close() + e.Unlock() + e.close() return errors.New("timed out sending message to streamCh") } - s.Unlock() + e.Unlock() return nil } -func (s *chatEventStream) close() { - s.Lock() - defer s.Unlock() +func (e *protoEventStream[T]) close() { + e.Lock() + defer e.Unlock() - if s.closed { + if e.closed { return } - s.closed = true - close(s.streamCh) + e.closed = true + close(e.ch) +} + +type ptr[T any] interface { + proto.Message + *T } -func boundedStreamChatEventsRecv( +func boundedReceive[Req any, ReqPtr ptr[Req]]( ctx context.Context, - streamer chatpb.Chat_StreamChatEventsServer, + stream grpc.ServerStream, timeout time.Duration, -) (req *chatpb.StreamChatEventsRequest, err error) { - done := make(chan struct{}) +) (ReqPtr, error) { + var err error + var req = new(Req) + doneCh := make(chan struct{}) + go func() { - req, err = streamer.Recv() - close(done) + err = stream.RecvMsg(req) + close(doneCh) }() select { - case <-done: + case <-doneCh: return req, err case <-ctx.Done(): - return nil, status.Error(codes.Canceled, "") + return req, status.Error(codes.Canceled, "") case <-time.After(timeout): - return nil, status.Error(codes.DeadlineExceeded, "timed out receiving message") - } -} - -type chatEventNotification struct { - chatId chat.ChatId - event *chatpb.ChatStreamEvent - ts time.Time -} - -func (s *Server) asyncNotifyAll(chatId chat.ChatId, event *chatpb.ChatStreamEvent) error { - m := proto.Clone(event).(*chatpb.ChatStreamEvent) - ok := s.chatEventChans.Send(chatId[:], &chatEventNotification{chatId, m, time.Now()}) - if !ok { - return errors.New("chat event channel is full") - } - return nil -} - -func (s *Server) asyncChatEventStreamNotifier(workerId int, channel <-chan interface{}) { - log := s.log.WithFields(logrus.Fields{ - "method": "asyncChatEventStreamNotifier", - "worker": workerId, - }) - - for value := range channel { - typedValue, ok := value.(*chatEventNotification) - if !ok { - log.Warn("channel did not receive expected struct") - continue - } - - log := log.WithField("chat_id", typedValue.chatId.String()) - - if time.Since(typedValue.ts) > time.Second { - log.Warn("channel notification latency is elevated") - } - - s.streamsMu.RLock() - for key, stream := range s.streams { - if !strings.HasPrefix(key, typedValue.chatId.String()) { - continue - } - - if err := stream.notify(typedValue.event, streamNotifyTimeout); err != nil { - log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) - } - } - s.streamsMu.RUnlock() + return req, status.Error(codes.DeadlineExceeded, "timeout receiving message") } } -// Very naive implementation to start -func monitorChatEventStreamHealth( +func monitorStreamHealth[Req any, ReqPtr ptr[Req]]( ctx context.Context, log *logrus.Entry, ssRef string, - streamer chatpb.Chat_StreamChatEventsServer, + streamer grpc.ServerStream, + validFn func(ReqPtr) bool, ) <-chan struct{} { - streamHealthChan := make(chan struct{}) + healthCh := make(chan struct{}) go func() { - defer close(streamHealthChan) + defer close(healthCh) for { - // todo: configurable timeout - req, err := boundedStreamChatEventsRecv(ctx, streamer, streamKeepAliveRecvTimeout) + req, err := boundedReceive[Req, ReqPtr](ctx, streamer, streamKeepAliveRecvTimeout) if err != nil { return } - switch req.Type.(type) { - case *chatpb.StreamChatEventsRequest_Pong: - log.Tracef("received pong from client (stream=%s)", ssRef) - default: - // Client sent something unexpected. Terminate the stream + if !validFn(req) { return } + log.Tracef("received pong from client (stream=%s)", ssRef) } }() - return streamHealthChan + return healthCh } diff --git a/pkg/code/server/grpc/chat/v2/streams.go b/pkg/code/server/grpc/chat/v2/streams.go new file mode 100644 index 00000000..5043e54b --- /dev/null +++ b/pkg/code/server/grpc/chat/v2/streams.go @@ -0,0 +1,522 @@ +package chat_v2 + +import ( + "bytes" + "context" + "errors" + "fmt" + "go.uber.org/zap" + "time" + + "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" + + chatpb "github.com/code-payments/code-protobuf-api/generated/go/chat/v2" + commonpb "github.com/code-payments/code-protobuf-api/generated/go/common/v1" + + "github.com/code-payments/code-server/pkg/code/common" + chat "github.com/code-payments/code-server/pkg/code/data/chat/v2" + "github.com/code-payments/code-server/pkg/database/query" + "github.com/code-payments/code-server/pkg/grpc/client" +) + +func (s *Server) StreamMessages(stream chatpb.Chat_StreamMessagesServer) error { + ctx := stream.Context() + log := s.log.WithField("method", "StreamMessages") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedReceive[chatpb.StreamMessagesRequest, *chatpb.StreamMessagesRequest]( + ctx, + stream, + 250*time.Millisecond, + ) + if err != nil { + return err + } + + if req.GetParams() == nil { + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") + } + + owner, err := common.NewAccountFromProto(req.GetParams().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + chatId, err := chat.GetChatIdFromProto(req.GetParams().ChatId) + if err != nil { + log.WithError(err).Warn("invalid chat id") + return status.Error(codes.Internal, "") + } + log = log.WithField("chat_id", chatId.String()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.GetParams().Signature + req.GetParams().Signature = nil + if err := s.auth.Authenticate(stream.Context(), owner, req.GetParams(), signature); err != nil { + return err + } + + isMember, err := s.data.IsChatMember(ctx, chatId, memberId) + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + if !isMember { + return stream.Send(&chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Error{ + Error: &chatpb.StreamError{Code: chatpb.StreamError_DENIED}, + }, + }) + } + + streamKey := fmt.Sprintf("%s:%s", chatId.String(), memberId.String()) + + s.streamsMu.Lock() + + if _, exists := s.streams[streamKey]; exists { + s.streamsMu.Unlock() + // There's an existing stream on this Server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + ss := newEventStream[*chatpb.StreamMessagesResponse_MessageBatch]( + streamBufferSize, + func(notification *chatEventNotification) (*chatpb.StreamMessagesResponse_MessageBatch, bool) { + if notification.messageUpdate == nil { + return nil, false + } + if notification.chatId != chatId { + return nil, false + } + + return &chatpb.StreamMessagesResponse_MessageBatch{ + Messages: []*chatpb.Message{notification.messageUpdate}, + }, true + }, + ) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = ss + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == ss { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorStreamHealth(ctx, log, streamRef, stream, func(t *chatpb.StreamMessagesRequest) bool { + return t.GetPong() != nil + }) + + // TODO: Support pagination options (or just remove if not necessary). + go s.flushMessages(ctx, chatId, owner, ss) + go s.flushPointers(ctx, chatId, owner, ss) + + for { + select { + case batch, ok := <-ss.ch: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + resp := &chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Messages{ + Messages: batch, + }, + } + + if err = stream.Send(resp); err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := stream.Send(&chatpb.StreamMessagesResponse{ + Type: &chatpb.StreamMessagesResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *Server) StreamChatEvents(stream chatpb.Chat_StreamChatEventsServer) error { + ctx := stream.Context() + log := s.log.WithField("method", "StreamChatEvents") + log = client.InjectLoggingMetadata(ctx, log) + + req, err := boundedReceive[chatpb.StreamChatEventsRequest, *chatpb.StreamChatEventsRequest]( + ctx, + stream, + 250*time.Millisecond, + ) + if err != nil { + return err + } + + if req.GetParams() == nil { + return status.Error(codes.InvalidArgument, "StreamChatEventsRequest.Type must be OpenStreamRequest") + } + + owner, err := common.NewAccountFromProto(req.GetParams().Owner) + if err != nil { + log.WithError(err).Warn("invalid owner account") + return status.Error(codes.Internal, "") + } + log = log.WithField("owner", owner.PublicKey().ToBase58()) + + memberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failed to derive messaging account") + return status.Error(codes.Internal, "") + } + log = log.WithField("member_id", memberId.String()) + + signature := req.GetParams().Signature + req.GetParams().Signature = nil + if err := s.auth.Authenticate(stream.Context(), owner, req.GetParams(), signature); err != nil { + return err + } + + // This should be safe? The user would have to provide a pub key + // that derives to a collision on another stream key (i.e. messages) + streamKey := fmt.Sprintf("%s", memberId.String()) + + s.streamsMu.Lock() + + if _, exists := s.streams[streamKey]; exists { + s.streamsMu.Unlock() + // There's an existing stream on this Server that must be terminated first. + // Warn to see how often this happens in practice + log.Warnf("existing stream detected on this Server (stream=%p) ; aborting", stream) + return status.Error(codes.Aborted, "stream already exists") + } + + ss := newEventStream[*chatpb.StreamChatEventsResponse_EventBatch]( + streamBufferSize, + func(notification *chatEventNotification) (*chatpb.StreamChatEventsResponse_EventBatch, bool) { + // We need to check memberships here. + // + // TODO: This needs to be heavily cached + isMember, err := s.data.IsChatMember(ctx, notification.chatId, memberId) + if err != nil { + log.Warn("Failed to check if member for notification", zap.String("chat_id", notification.chatId.String())) + } else if !isMember { + log.Debug("Notification for chat not a member of, dropping", zap.String("chat_id", notification.chatId.String())) + return nil, false + } + + return &chatpb.StreamChatEventsResponse_EventBatch{ + Updates: []*chatpb.StreamChatEventsResponse_ChatUpdate{ + { + ChatId: notification.chatId.ToProto(), + Metadata: notification.chatUpdate, + LastMessage: notification.messageUpdate, + Pointer: notification.pointerUpdate, + IsTyping: notification.isTyping, + }, + }, + }, true + }, + ) + + // The race detector complains when reading the stream pointer ref outside of the lock. + streamRef := fmt.Sprintf("%p", stream) + log.Tracef("setting up new stream (stream=%s)", streamRef) + s.streams[streamKey] = ss + + s.streamsMu.Unlock() + + defer func() { + s.streamsMu.Lock() + + log.Tracef("closing streamer (stream=%s)", streamRef) + + // We check to see if the current active stream is the one that we created. + // If it is, we can just remove it since it's closed. Otherwise, we leave it + // be, as another StreamChatEvents() call is handling it. + liveStream, exists := s.streams[streamKey] + if exists && liveStream == ss { + delete(s.streams, streamKey) + } + + s.streamsMu.Unlock() + }() + + sendPingCh := time.After(0) + streamHealthCh := monitorStreamHealth(ctx, log, streamRef, stream, func(t *chatpb.StreamMessagesRequest) bool { + return t.GetPong() != nil + }) + + go s.flushChats(ctx, owner, memberId, ss) + + for { + select { + case batch, ok := <-ss.ch: + if !ok { + log.Tracef("stream closed ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Aborted, "stream closed") + } + + resp := &chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Events{ + Events: batch, + }, + } + + if err = stream.Send(resp); err != nil { + log.WithError(err).Info("failed to forward chat message") + return err + } + case <-sendPingCh: + log.Tracef("sending ping to client (stream=%s)", streamRef) + + sendPingCh = time.After(streamPingDelay) + + err := stream.Send(&chatpb.StreamChatEventsResponse{ + Type: &chatpb.StreamChatEventsResponse_Ping{ + Ping: &commonpb.ServerPing{ + Timestamp: timestamppb.Now(), + PingDelay: durationpb.New(streamPingDelay), + }, + }, + }) + if err != nil { + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + } + case <-streamHealthCh: + log.Tracef("stream is unhealthy ; aborting (stream=%s)", streamRef) + return status.Error(codes.Aborted, "terminating unhealthy stream") + case <-ctx.Done(): + log.Tracef("stream context cancelled ; ending stream (stream=%s)", streamRef) + return status.Error(codes.Canceled, "") + } + } +} + +func (s *Server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushMessages", + "chat_id": chatId.String(), + "owner_account": owner.PublicKey().ToBase58(), + }) + + protoChatMessages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithCursor(query.EmptyCursor), + query.WithDirection(query.Descending), + query.WithLimit(flushMessageCount), + ) + if err != nil { + log.WithError(err).Warn("failure getting chat messages") + return + } + + for _, protoChatMessage := range protoChatMessages { + event := &chatEventNotification{ + chatId: chatId, + messageUpdate: protoChatMessage, + } + + if err = stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } +} + +func (s *Server) flushPointers(ctx context.Context, chatId chat.ChatId, owner *common.Account, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushPointers", + "chat_id": chatId.String(), + }) + + callingMemberId, err := owner.ToChatMemberId() + if err != nil { + log.WithError(err).Warn("failure computing self") + return + } + + memberRecords, err := s.data.GetChatMembersV2(ctx, chatId) + if err != nil { + log.WithError(err).Warn("failure getting chat members") + return + } + + for _, memberRecord := range memberRecords { + for _, optionalPointer := range []struct { + kind chat.PointerType + value *chat.MessageId + }{ + {chat.PointerTypeDelivered, memberRecord.DeliveryPointer}, + {chat.PointerTypeRead, memberRecord.ReadPointer}, + } { + if optionalPointer.value == nil { + continue + } + + memberId, err := chat.GetMemberIdFromString(memberRecord.MemberId) + if err != nil { + log.WithError(err).Warnf("failure getting memberId from %s", memberRecord.MemberId) + return + } + + if bytes.Equal(memberId, callingMemberId) { + continue + } + + event := &chatEventNotification{ + pointerUpdate: &chatpb.Pointer{ + Type: optionalPointer.kind.ToProto(), + Value: optionalPointer.value.ToProto(), + MemberId: memberId.ToProto(), + }, + } + if err := stream.notify(event, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + return + } + } + } +} + +func (s *Server) flushChats(ctx context.Context, owner *common.Account, memberId chat.MemberId, stream eventStream) { + log := s.log.WithFields(logrus.Fields{ + "method": "flushChats", + "member_id": memberId.String(), + }) + + chats, err := s.data.GetAllChatsForUserV2(ctx, memberId) + if err != nil { + log.WithError(err).Warn("failed get chats") + return + } + + // TODO: This needs to be far safer. + for _, chatId := range chats { + go func(chatId chat.ChatId) { + md, err := s.getMetadata(ctx, memberId, chatId) + if err != nil { + log.WithError(err).Warn("failed get metadata", zap.String("chat_id", chatId.String())) + return + } + + messages, err := s.getProtoChatMessages( + ctx, + chatId, + owner, + query.WithLimit(1), + query.WithDirection(query.Descending), + ) + if err != nil { + log.WithError(err).Warn("failed get chat messages", zap.String("chat_id", chatId.String())) + } + + event := &chatEventNotification{ + chatId: chatId, + chatUpdate: md, + } + if len(messages) > 0 { + event.messageUpdate = messages[0] + } + }(chatId) + } +} + +type chatEventNotification struct { + chatId chat.ChatId + ts time.Time + + chatUpdate *chatpb.Metadata + pointerUpdate *chatpb.Pointer + messageUpdate *chatpb.Message + isTyping *chatpb.IsTyping +} + +func (s *Server) asyncNotifyAll(chatId chat.ChatId, event *chatEventNotification) error { + event.ts = time.Now() + ok := s.chatEventChans.Send(chatId[:], event) + if !ok { + return errors.New("chat event channel is full") + } + + return nil +} + +func (s *Server) asyncChatEventStreamNotifier(workerId int, channel <-chan any) { + log := s.log.WithFields(logrus.Fields{ + "method": "asyncChatEventStreamNotifier", + "worker": workerId, + }) + + for value := range channel { + typedValue, ok := value.(*chatEventNotification) + if !ok { + log.Warn("channel did not receive expected struct") + continue + } + + log = log.WithField("chat_id", typedValue.chatId.String()) + + if time.Since(typedValue.ts) > time.Second { + log.Warn("channel notification latency is elevated") + } + + s.streamsMu.RLock() + for _, stream := range s.streams { + if err := stream.notify(typedValue, streamNotifyTimeout); err != nil { + log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream) + } + } + s.streamsMu.RUnlock() + } +}