1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-07-29 08:52:41 +00:00

Compare commits

...

166 commits

Author SHA1 Message Date
diamondburned 16a408bf30 wsutil: Refactored and decoupled structures for better thread safety 2020-10-28 10:19:22 -07:00
diamondburned 6c332ac145 {Voice,}Gateway: Fixed various race conditions
This commit fixes race conditions in both package voice, package
voicegateway and package gateway.

Originally, several race conditions exist when both the user's and the
pacemaker's goroutines both want to do several things to the websocket
connection. For example, the user's goroutine could be writing, and the
pacemaker's goroutine could trigger a reconnection. This is racey.

This issue is partially fixed by removing the pacer loop from package
heart and combining the ticker into the event (pacemaker) loop itself.

Technically, a race condition could still be triggered with care, but
the API itself never guaranteed any of those. As events are handled
using an internal loop into a channel, a race condition will not be
triggered just by handling events and writing to the websocket.
2020-10-22 10:47:27 -07:00
diamondburned 91ee92e9d5 Gateway: Fixed a race condition on ReconnectOP 2020-10-21 22:42:16 -07:00
diamondburned 86795e42a6 Session: Fixed a potential race condition on Close 2020-10-21 22:42:16 -07:00
Maximilian von Lindern 397d288927
API: fix errors in message pagination and streamline changes with other pagination methods (#150)
* API: fix faulty pagination behavior

This fix fixes a condition which lead to all messages getting fetched if the limit was a multiple of 100, instead of just the limit.

* API: add NewestMessages

* API: clarify MessageAfter docs

* API: adapt paginating methods for guild, member and message reaction to match the style of message's pagination methods

* API: return nil if no items were fetched

* API: remove Messages and Rename NewestMessages to Messages
2020-10-19 07:47:43 -07:00
diamondburned dec39c4c2d API: Fixed Messages{Before,After} fetching incorrectly beyond 100s 2020-10-18 22:14:49 -07:00
mavolin 6dabffb46c State: fix case where Role would return nil error, even though no role was found 2020-10-18 13:44:37 -07:00
diamondburned 1bec57523d Gateway: GuildSubscribeData should omit empty Channels map 2020-10-17 03:18:50 -07:00
diamondburned 86dd05da9e Gateway: Fixed empty Query on RequestGuildMembersData broken 2020-10-16 02:17:59 -07:00
mavolin 647efb8030 Discord: add Mention method to mentionable Snowflakes 2020-09-24 11:54:45 -07:00
diamondburned 64ab8c4f30 Bot: Fixed trailing backticks causing out of bound panic 2020-08-29 22:09:58 -07:00
mavolin 5acf9f3f22 Discord: fix invalid role mention generation 2020-08-24 16:32:51 -07:00
mavolin 7d5cc89ff0 API: add KickWithReason 2020-08-22 10:05:37 -07:00
diamondburned 6b4e26e839 wsutil: Improved internal code 2020-08-20 14:15:52 -07:00
diamondburned fd818e181e Gateway: GuildFolderID is now a signed int because Discord 2020-08-19 21:54:20 -07:00
diamondburned 87c648ae1d Discord: ParseSnowflake now uses ParseUint 2020-08-19 21:53:22 -07:00
diamondburned 3312c66515 Voice: Made EventLoop a valid struct value instead of nil pointer 2020-08-19 21:32:40 -07:00
diamondburned de61fd912d wsutil: Made PacemakerLoop valid as zero-value 2020-08-19 21:30:57 -07:00
diamondburned f0c73f4c99 State: Ready events now automatically reset the state 2020-08-18 10:20:48 -07:00
Maximilian von Lindern a7e9439109
Discord/API: implement changes to permission, allow and deny fields (#141) 2020-08-17 17:10:43 -07:00
diamondburned af7f413cea Gateway: Clarified GuildMemberListGroup.ID docs 2020-08-14 21:13:48 -07:00
diamondburned c819b56170 Gateway: Added a custom GuildFolderID type 2020-08-14 18:13:35 -07:00
diamondburned eb46a89e6c Gateway: Fixed GuildFolder.ID unmarshaling 2020-08-14 17:57:06 -07:00
diamondburned d888a5a7ac Bot: Added better middleware documentation 2020-08-11 17:31:29 -07:00
diamondburned 3db68bcb0e Bot: Allow hanging quotes if command has a custom parser 2020-08-11 13:44:32 -07:00
diamondburned 94cca0adca httputil: Fixed unlock of unlocked mutex bug 2020-08-04 14:09:43 -07:00
mavolin 2a032ebfab Discord: add watching activity 2020-08-03 17:46:04 -07:00
diamondburned 77b1b08bce Heart: Better synchronization on close methods 2020-07-30 12:44:50 -07:00
mavolin 362929fad5 Webhook: fix incorrect order of parameters 2020-07-29 20:03:24 -07:00
mavolin ba1fc650d1 API: fix wrong typed Snowflake 2020-07-29 20:03:24 -07:00
Maximilian von Lindern 1585797b52 *: Linting and typo fixes (#134)
* Linting and typo fixes

* Linting and typo fixes

* revert comma fix
2020-07-29 16:58:33 -07:00
mavolin 8baf8ee84b Multipartutil: move back to package api 2020-07-29 16:58:33 -07:00
Maximilian von Lindern 908ef96089 Discord: Uint64 typed Snowflakes (#132)
* Use typed Snowflakes if possible

* Discord: make Snowflakes uint64

* Fix errors that emerged because of new typing
2020-07-29 16:58:33 -07:00
Maximilian von Lindern 32789bb6e2 *: Separate utils and internal (#129)
* Utils: move package utils/heart to internal/heart

* Utils: move package utils/moreatomic to internal/moreatomic

* Utils: move package utils/zlib to internal/zlib
2020-07-29 16:58:33 -07:00
Maximilian von Lindern 78c36f13cd Discord: Rename Snowflake and Timestamp Valid methods (#128)
* Discord: rename Snowflake.Valid() to IsValid()

* Discord: rename Timestamp.Valid() to IsValid()
2020-07-29 16:58:33 -07:00
Maximilian von Lindern e1d9685cdb API: separate token-based and bot-based interactions with webhooks (#130)
* API: separate token-based and bot-based interactions with webhooks

* API: move writeMultipart to internal/multipartutil

* Multipartutil: fix double filetype-suffix
2020-07-29 16:58:33 -07:00
mavolin ba4b224168 handler: move package from /handler to /utils/handler 2020-07-29 16:58:33 -07:00
diamondburned e79132f2c5 State: Breaking API to fix race conditions in store 2020-07-29 16:58:33 -07:00
Tadeo Kondrak b8f6fbbda9 Gateway: Fix type of GuildFolder.ID 2020-07-29 16:58:33 -07:00
Tadeo Kondrak d290b0d01c *: Add typed Snowflake IDs (#122)
This PR closes #120.
2020-07-29 16:58:33 -07:00
diamondburned 24f7ed0499 Gateway: ReconnectCtx now returns error; fixed test 2020-07-18 18:33:07 -07:00
diamondburned a929817c0f Handler: Fixed data race in test 2020-07-18 18:25:00 -07:00
diamondburned 1c8aaaefcc State: Fixed individual message fetch missing GuildID 2020-07-17 11:35:44 -07:00
diamondburned d18298aca9 Discord: Updated message's URL 2020-07-17 11:33:57 -07:00
diamondburned 35e143a99f Handler: Added blocking send cleanup to avoid goroutine leak 2020-07-15 23:11:20 -07:00
diamondburned 6717f8002c Gateway: Fixed autoreconnect misusing context 2020-07-15 16:39:40 -07:00
diamondburned a1038cb8bb Gateway: Fixed wrong usage of Context in Gateway reconnection 2020-07-15 16:32:53 -07:00
diamondburned 880691c51b Handler: Fixed inconsistency in documentation 2020-07-15 00:48:50 -07:00
diamondburned cb8567f006 Handler: Added examples as comments for documentation 2020-07-15 00:05:35 -07:00
diamondburned 18024526fe Handler: Added support for channel event handlers 2020-07-14 23:57:50 -07:00
diamondburned 9d7f5cb953 Gateway: Deprecated useless type definitions for embedded structs 2020-07-14 21:38:31 -07:00
diamondburned 5b37b2ab0d Gateway: Allow for longer timeouts 2020-07-14 18:47:52 -07:00
diamondburned c1885067d7 Gateway: Allow for more lenient gateway bursts 2020-07-14 18:47:15 -07:00
diamondburned 7572caad31 Discord: Added Relationship API methods; moved structs around 2020-07-14 18:01:24 -07:00
diamondburned 56b1a7cce8 Bot: Help generators now allow generating hidden commands 2020-07-14 16:33:21 -07:00
diamondburned bf7ca8450d API: Fixed Ban panicking 2020-07-11 18:51:01 -07:00
diamondburned 712a061e8e API: Added SetNote for user accounts 2020-07-11 14:27:03 -07:00
diamondburned 91e494ba51 Gateway: Changed Relationship struct for type and name claritifcation 2020-07-11 13:25:29 -07:00
diamondburned edb8a46ef2 Gateway: Added intent helpers and more context API support 2020-07-11 12:50:32 -07:00
diamondburned f33b4ff7d8 wsutil: API changed to support contexts 2020-07-11 12:49:28 -07:00
diamondburned a0785bd657 CI: Increased time limit to account for slow integration tests 2020-07-11 00:06:31 -07:00
diamondburned d3d9811276 Gateway: Added Relationship events and handlers; minor reformatting 2020-07-11 00:02:57 -07:00
diamondburned (Forefront) 16ed406c53 Session: Fixed a panic bug when the gateway fails 2020-06-29 11:00:07 -07:00
diamondburned (Forefront) 01021f0902 Fixed a compile bug 2020-06-19 00:59:44 -07:00
diamondburned (Forefront) 88dd0f8995 State now handles MsgCreate's missing Member.User field, some bug fixes
This addresses discord/discord-api-docs#1440.

State documentation has been added, which documents the store and
handlers as well.

Bug fixes include:

- PreHandler being called after the state handler; it is now called
before as documented.
- Minor behavior changes regarding Guild Create events. Refer to State's
documentation.
2020-06-19 00:33:22 -07:00
Maximilian von Lindern 1373e42fe1
State: fix State.Message not working when the message's channel is not found in the Store (#117)
* State: fix State.Message not working when the message's channel is not found in the Store

* State: fix State.Message not working when the message's channel is not found in the Store
2020-06-08 07:30:16 -07:00
Maximilian von Lindern de3d0e2160
Gateway: Split GuildCreateEvent (#116)
* Session: fix event handler loop not getting properly closed

* Implement #113

* Session: move guild events to state

* Session: close hStop
2020-06-06 13:47:15 -07:00
mavolin 943ca00ae5 State: implement #114 2020-06-06 10:24:34 -07:00
mavolin efd2ce4c03 State: reduce times a go routine is spawned 2020-06-06 10:24:34 -07:00
diamondburned (Forefront) 9ce0620652 Cleaned up go.mod 2020-06-01 13:52:15 -07:00
diamondburned (Forefront) 9747675741 CI: Fixed syntax 2020-05-30 15:22:13 -07:00
diamondburned (Forefront) 77d6067340 CI: Better mutual exclusivity of unit and integration tests 2020-05-30 15:19:51 -07:00
diamondburned (Forefront) 783dfe7ba6 CI: Unit and integration tests are now mutually exclusive 2020-05-30 15:15:32 -07:00
diamondburned (Forefront) f91518f3c6 CI: haha json go brr 2020-05-30 14:45:22 -07:00
diamondburned (Forefront) bafeb1082a CI: Added Dismock, better coverage parsing 2020-05-30 14:32:23 -07:00
ks129 93fbfd98d0 Fix aliases appending
Add 3 dots on appending to merge 2 slices.
2020-05-24 23:08:55 -07:00
ks129 23d97044ec Simplify aliases adding
- Removed duplicates check
- Fixed docstring
2020-05-24 23:08:55 -07:00
ks129 75fe1bd03a Implement command aliases
- Add alias parsing to `Context.findCommand`.
- Add new function to `Subcommand`: `AddAliases` that add new alias(es) to command.
- Added `Aliases` property to `MethodContext`
2020-05-24 23:08:55 -07:00
mavolin 960ba486bd API: code cleanup 2020-05-24 17:28:04 -07:00
mavolin a07f343b39 API: fix id field getting sent on EditChannelPermission 2020-05-24 17:28:04 -07:00
mavolin 52bec08cc6 Discord: fix discovery splash url not getting properly calculated 2020-05-24 17:24:58 -07:00
mavolin 6c3b1e0c56 API: verify Embed and AllowedMentions on message edit 2020-05-24 16:57:23 -07:00
mavolin 55e9c28d37 API: fix Message.Author.ID getting sent instead of Message.ID 2020-05-24 09:08:47 -07:00
mavolin 46b001548d Discord: fix wrong hash for discovery splash url 2020-05-24 09:08:24 -07:00
mavolin 19b970bad7 Discord: fix #105 2020-05-23 14:24:03 -07:00
Maximilian von Lindern ce38507fb0
Discord: fixes around meta images (#104)
* API: fix illogical order of parameters

* Discord: fixes around meta images
2020-05-23 10:17:30 -07:00
mavolin 6fbc3e6afd Discord: fix #102 2020-05-23 09:04:12 -07:00
mavolin 147b01641b API: fix illogical order of parameters 2020-05-22 19:19:08 -07:00
mavolin b67b993095 API: fix illogical order of parameters 2020-05-22 19:14:01 -07:00
mavolin 6cc6d05f5f API: use CreateWebhookData instead of direct arguments 2020-05-22 19:00:51 -07:00
mavolin 93d9323b3b API: fix accidental append instead of prepend 2020-05-22 17:09:55 -07:00
mavolin 7b52582c93 API: fix unlimited pagination error 2020-05-22 16:18:01 -07:00
mavolin e4b43c0a83 Discord: fix invalid calc of after 2020-05-22 16:01:09 -07:00
mavolin 530bff74a2 Discord: fix faulty default image link 2020-05-22 14:40:36 -07:00
mavolin eefb6d731c API: implement #93 2020-05-22 12:27:57 -07:00
mavolin a76c9031c1 API: fix #91 2020-05-22 10:52:30 -07:00
mavolin 5fefaf07c4 API: fix #89 2020-05-22 10:47:46 -07:00
Maximilian von Lindern 68701704a1 Discord: fix wrong field naming 2020-05-22 08:14:35 -07:00
mavolin 68d3129bfd Discord: add docs 2020-05-22 08:14:35 -07:00
mavolin f4be7971ee Discord: add missing fields to Guild struct 2020-05-22 08:14:35 -07:00
diamondburned (Forefront) 9da01cccb3 Voice: Fixed a potential Write() stalling bug 2020-05-20 15:05:50 -07:00
mavolin c5f1bf4753 Discord: add docs for auditlog.go 2020-05-18 10:25:17 -07:00
diamondburned (Forefront) 53c1ea0f0d State: Fixed Discord not setting GuildID for Ready.Guild.Channels 2020-05-17 23:11:14 -07:00
diamondburned (Forefront) 64c6ca7916 Gateway: Fixed GuildCreateEvent not having Channels 2020-05-17 22:48:16 -07:00
diamondburned (Forefront) dc303a8635 Merge branch 'master' of github.com:diamondburned/arikawa 2020-05-17 13:31:15 -07:00
diamondburned (Forefront) dfcf6770c3 State: Fixed message out-of-bound during copying 2020-05-17 13:31:08 -07:00
mavolin 805df29c2e Discord: fix #83 2020-05-17 10:34:47 -07:00
mavolin 1c53befad4 API: fix wrong endpoint for GuildWidget and ModifyGuildWidget 2020-05-17 10:34:11 -07:00
diamondburned 032ae736ab
Merge pull request #80 from mavolin/lowercase-errors 2020-05-16 17:32:17 -07:00
diamondburned 05f01964fe
Merge pull request #81 from mavolin/77-no-limits 2020-05-16 17:31:43 -07:00
mavolin b3fabae701
API: implement #77 2020-05-17 01:35:57 +02:00
mavolin 41ce1f389e
make all error messages lowercase 2020-05-16 23:14:49 +02:00
diamondburned ff8ebcbacf
Merge pull request #79 from mavolin/76-async-state 2020-05-16 13:41:37 -07:00
mavolin 38b2d4d2b4
State: fix errors not returned 2020-05-16 22:36:46 +02:00
mavolin 9ce1a967d8
Merge remote-tracking branch 'origin/76-async-state' into 76-async-state 2020-05-16 22:19:47 +02:00
mavolin 6202f53ebb
State: implement #76 2020-05-16 22:15:59 +02:00
mavolin d9e0580e45
State: implement #76 2020-05-16 22:13:40 +02:00
diamondburned (Forefront) 130e60c162 Merge branch 'master' of github.com:diamondburned/arikawa 2020-05-16 13:06:12 -07:00
diamondburned (Forefront) 5aec467779 State: Fixed incoming messages being backwards in order 2020-05-16 13:05:11 -07:00
diamondburned 24e29ecb38
Merge pull request #75 from mavolin/74-guild-widget 2020-05-16 10:19:25 -07:00
mavolin 22a6994c50
Discord: implement #74 2020-05-16 17:57:25 +02:00
diamondburned 1df6bf61fc Merge pull request #73 from mavolin/master 2020-05-15 13:21:16 -07:00
Maximilian von Lindern 2c98f4e8e4
Discord: fix typo 2020-05-15 22:13:15 +02:00
mavolin 2c4022734b
Discord: create XURLWithType methods 2020-05-15 22:07:27 +02:00
mavolin ceb6986749
Discord: add ImageType 2020-05-15 21:52:45 +02:00
diamondburned aa99f50b9c
Merge pull request #72 from mavolin/70-endpoint
Discord: fix wrong URL endpoint
2020-05-15 12:16:22 -07:00
mavolin 6ee5a1b26d
Discord: fix #70 2020-05-15 20:39:08 +02:00
mavolin 48a13e1fe8
API: add missing trailing linefeed 2020-05-15 20:38:22 +02:00
diamondburned 0c21aa8571
Merge pull request #69 from mavolin/46-guild-preview
API: Add missing GuildPreview endpoint
2020-05-15 11:28:29 -07:00
mavolin b5dedf9408
API/Discord: add GuildPreview 2020-05-15 20:10:35 +02:00
diamondburned 65641652c4
Merge pull request #67 from mavolin/55-user-connections 2020-05-15 10:50:11 -07:00
diamondburned f05eb0d5a8
Merge branch 'master' into 55-user-connections 2020-05-15 10:49:46 -07:00
diamondburned 6f6cf2b85c
Merge pull request #66 from mavolin/63-edit 2020-05-15 10:45:53 -07:00
mavolin 7cb3520bc7
API: re-add pointer to embed to preserve optionality 2020-05-15 19:40:00 +02:00
diamondburned 2d8e08e5e9
Merge pull request #68 from mavolin/54-create-group 2020-05-15 10:39:58 -07:00
mavolin b6fee46f69
API: remove comment 2020-05-15 19:37:10 +02:00
mavolin 4eef15ec7d
API: add UserConnections method (#55) 2020-05-15 19:31:23 +02:00
mavolin a963ea46f1
API: fix missing doc 2020-05-15 19:29:04 +02:00
mavolin 52a7582dab
API: remove the pointer of Embed to concur with SendMessage 2020-05-15 19:19:13 +02:00
mavolin 11cf1eb769
API: implement #63 2020-05-15 19:17:52 +02:00
mavolin e362b10084
API: update docs 2020-05-15 19:14:37 +02:00
diamondburned 795a69ca7d Bot: Added EditableCommands 2020-05-14 18:52:15 -07:00
diamondburned 6b8628804f
Merge pull request #65 from diamondburned/nf-deprecation 2020-05-14 18:16:17 -07:00
diamondburned 17c620dd5a Bot: Updated arguments test 2020-05-14 18:07:16 -07:00
diamondburned fe950de9e0 Examples: Fixed up advanced_bot to make it up to date 2020-05-14 18:01:48 -07:00
diamondburned (Forefront) 7d683a2ace Examples: Updated advanced_bot to the new API 2020-05-14 14:20:23 -07:00
diamondburned 729979088c Bot: Added more tests and the Help API 2020-05-14 14:04:18 -07:00
diamondburned (Forefront) 6613aa5b41 Bot: Added tests for middlewares 2020-05-14 13:59:17 -07:00
diamondburned (Forefront) e556d2afad Bot: Partially implemented middlewares 2020-05-14 13:59:17 -07:00
diamondburned (Forefront) 67430c6d7a Bot: Added Plumb support, fixed tests
Merged

Merged
2020-05-14 13:59:17 -07:00
diamondburned (Forefront) 9e59402591 Bot: Added tests for middlewares 2020-05-14 13:59:17 -07:00
diamondburned (Forefront) 964e8cdf13 Bot: Partially implemented middlewares 2020-05-14 13:59:17 -07:00
diamondburned (Forefront) 1ca7d1c62c Discord: Fixed Snowflake Valid() returning true for null 2020-05-14 01:03:48 -07:00
diamondburned (Forefront) 763999a81e API: Added extra Emoji documentation 2020-05-14 00:49:57 -07:00
diamondburned (Forefront) aa07ff9a43 Discord: Fixed Emoji.APIString() mishandling Unicode characters 2020-05-14 00:49:10 -07:00
diamondburned (Forefront) c4504808a2 Bot: Updated message URL for arguments 2020-05-14 00:19:51 -07:00
diamondburned 5ac262163e
Merge pull request #64 from matthewpi/fix/voice-close 2020-05-13 14:55:49 -07:00
Matthew Penner 592d2f7172 Fix null snowflakes being formatted as 18446744073709551615 2020-05-13 15:51:26 -06:00
Matthew Penner 3aa92c8f05 Only add the guild to SessionErrors if the error is not nil 2020-05-13 15:38:07 -06:00
Matthew Penner 60346f23bb Close voice connections when Close() is called 2020-05-13 14:43:00 -06:00
diamondburned (Forefront) adb23eeb8e Gateway: Added new InviteCreateEvent and InviteDeleteEvent 2020-05-12 17:51:23 -07:00
diamondburned (Forefront) 91bc93f331 API: Fixed integration test 2020-05-12 17:34:36 -07:00
diamondburned (Forefront) 694c074902 API: Fixed EditMessage test 2020-05-12 17:29:22 -07:00
diamondburned (Forefront) ae793848aa Utils: Exposed NullableTData structs 2020-05-12 17:09:43 -07:00
diamondburned (Forefront) cc4e8c0966 API: Added test for EditMessage 2020-05-12 17:09:32 -07:00
diamondburned (Forefront) 9f1f7547b9 API: Reversed EditMessage API for backwards compatibility and consistency 2020-05-12 16:51:35 -07:00
119 changed files with 6546 additions and 3724 deletions

View file

@ -1,39 +1,52 @@
image: golang:alpine
variables:
GO111MODULE: "on"
CGO_ENABLED: "0"
COV: "/tmp/cov_results"
before_script:
- apk add git
stages:
- test
build_test:
stage: test
script:
- go build ./...
unit_test:
stage: test
script:
- go test -v -coverprofile $COV ./...
- go tool cover -func $COV
| grep -F 'total:'
| sed -E 's|total:\s+\(.*?\)\s+([0-9]+\.[0-9]+%)|TEST_COVERAGE=\1|'
integration_test:
stage: test
script:
# Don't run if these variables aren't provided.
- '[ ! "$BOT_TOKEN" ] && exit'
# go get first, so it doesn't count towards the timeout.
- go get ./...
# Timeout test after 120 seconds (2 minutes)
- timeout 120 go test -tags integration -v -coverprofile $COV ./...
- go tool cover -func $COV
| grep -F 'total:'
| sed -E 's|total:\s+\(.*?\)\s+([0-9]+\.[0-9]+%)|TEST_COVERAGE=\1|'
{
"image": "golang:alpine",
"variables": {
"GO111MODULE": "on",
"CGO_ENABLED": "0",
"COV": "/tmp/cov_results",
"dismock": "github.com/mavolin/dismock/pkg/dismock",
# used only in integration_test
"tested": "./api,./gateway,./bot,./discord"
},
"before_script": [
"apk add git"
],
"stages": [
"build",
"test"
],
"build_test": {
"stage": "build",
"script": [
"go build ./..."
]
},
"unit_test": {
"stage": "test",
"timeout": "2m", # 2 minutes
# Don't run the test if we have a $BOT_TOKEN, because
# integration_test will run instead.
"except": {
"variables": [ "$BOT_TOKEN" ]
},
"script": [
"go test -v -coverprofile $COV ./...",
"go tool cover -func $COV"
]
},
"integration_test": {
"stage": "test",
"timeout": "5m", # 5 minutes
# Run the test only if we have $BOT_TOKEN, else fallback to unit
# tests.
"only": {
"variables": [ "$BOT_TOKEN" ]
},
"script": [
"go get ./...",
# Test this package along with dismock.
"go test -coverpkg $tested -coverprofile $COV -tags integration -v ./... $dismock",
"go tool cover -func $COV"
]
}
}

View file

@ -3,7 +3,7 @@
[![Pipeline status](https://gitlab.com/diamondburned/arikawa/badges/master/pipeline.svg?style=flat-square)](https://gitlab.com/diamondburned/arikawa/pipelines )
[![ Coverage](https://gitlab.com/diamondburned/arikawa/badges/master/coverage.svg?style=flat-square)](https://gitlab.com/diamondburned/arikawa/commits/master )
[![ Report Card](https://goreportcard.com/badge/github.com/diamondburned/arikawa?style=flat-square )](https://goreportcard.com/report/github.com/diamondburned/arikawa)
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue?style=flat-square )](https://godoc.org/github.com/diamondburned/arikawa )
[![Godoc Reference](https://img.shields.io/badge/godoc-reference-blue?style=flat-square )](https://pkg.go.dev/github.com/diamondburned/arikawa )
[![ Examples](https://img.shields.io/badge/Example-__example%2F-blueviolet?style=flat-square )](https://github.com/diamondburned/arikawa/tree/master/_example )
[![Discord Gophers](https://img.shields.io/badge/Discord%20Gophers-%23arikawa-%237289da?style=flat-square)](https://discord.gg/7jSf85J )
[![ Hime Arikawa](https://img.shields.io/badge/Hime-Arikawa-ea75a2?style=flat-square )](https://hime-goto.fandom.com/wiki/Hime_Arikawa )
@ -15,7 +15,7 @@ A Golang library for the Discord API.
### [Simple](https://github.com/diamondburned/arikawa/tree/master/_example/simple)
Simple bot example without any state. All it does is logging messages sent into
the console. Run with `BOT_TOKEN="TOKEN" go run .`
the console. Run with `BOT_TOKEN="TOKEN" go run .`.
### [Undeleter](https://github.com/diamondburned/arikawa/tree/master/_example/undeleter)
@ -24,8 +24,8 @@ everything, including messages. It detects when someone deletes a message,
logging the content into the console.
This example demonstrates the PreHandler feature of this library. PreHandler
calls all handlers that are registered (separately from session), calling them
before the state is updated.
calls all handlers that are registered (separately from the session), calling
them before the state is updated.
### [Advanced Bot](https://github.com/diamondburned/arikawa/tree/master/_example/advanced_bot)
@ -34,7 +34,7 @@ that's built-in. The router turns exported struct methods into commands, its
arguments into command arguments, and more.
The library has a pretty detailed documentation available in [GoDoc
Reference](https://godoc.org/github.com/diamondburned/arikawa/bot).
Reference](https://pkg.go.dev/github.com/diamondburned/arikawa/bot).
## Comparison: Why not discordgo?
@ -63,20 +63,10 @@ custom remote or local state storage.
things in the state, which is useful for keeping it updated.
- No code generation: just so the library is a lot easier to maintain.
## You-should-knows
- ~~The bot will fatally exit if it fails to reconnect to the Gateway after a
certain amount of times. This is changeable in `gateway.WSFatal`, or
`(*Gateway).FatalLog`.~~
- ~~The bot will error out if the initial connection fails. However,
reconnections will be retried forever until it succeeds.~~ This is no longer
true. The bot will retry until `WSRetries` is reached, then the error will go
to `(*Gateway).FatalError` or `(*Gateway).Wait()`.
## Testing
The package includes integration tests that require `$BOT_TOKEN`. To run these
tests, do
tests, do:
```sh
export BOT_TOKEN="<BOT_TOKEN>"

View file

@ -10,6 +10,7 @@ import (
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/arguments"
"github.com/diamondburned/arikawa/bot/extras/middlewares"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
)
@ -19,57 +20,51 @@ type Bot struct {
Ctx *bot.Context
}
func (bot *Bot) Setup(sub *bot.Subcommand) {
// Only allow people in guilds to run guildInfo.
sub.AddMiddleware("GuildInfo", middlewares.GuildOnly(bot.Ctx))
}
// Help prints the default help message.
func (bot *Bot) Help(m *gateway.MessageCreateEvent) (string, error) {
func (bot *Bot) Help(*gateway.MessageCreateEvent) (string, error) {
return bot.Ctx.Help(), nil
}
// Add demonstrates the usage of typed arguments. Run it with "~add 1 2".
func (bot *Bot) Add(m *gateway.MessageCreateEvent, a, b int) error {
content := fmt.Sprintf("%d + %d = %d", a, b, a+b)
_, err := bot.Ctx.SendMessage(m.ChannelID, content, nil)
return err
func (bot *Bot) Add(_ *gateway.MessageCreateEvent, a, b int) (string, error) {
return fmt.Sprintf("%d + %d = %d", a, b, a+b), nil
}
// Ping is a simple ping example, perhaps the most simple you could make it.
func (bot *Bot) Ping(m *gateway.MessageCreateEvent) error {
_, err := bot.Ctx.SendMessage(m.ChannelID, "Pong!", nil)
return err
func (bot *Bot) Ping(*gateway.MessageCreateEvent) (string, error) {
return "Pong!", nil
}
// Say demonstrates how arguments.Flag could be used without the flag library.
func (bot *Bot) Say(
m *gateway.MessageCreateEvent, f *arguments.Flag) (string, error) {
args := f.String()
if args == "" {
// Empty message, ignore
return "", nil
func (bot *Bot) Say(_ *gateway.MessageCreateEvent, f bot.RawArguments) (string, error) {
if f != "" {
return string(f), nil
}
return args, nil
return "", errors.New("missing content")
}
// GuildInfo demonstrates the use of command flags, in this case the GuildOnly
// flag.
func (bot *Bot) GーGuildInfo(m *gateway.MessageCreateEvent) (string, error) {
g, err := bot.Ctx.Guild(m.GuildID)
// GuildInfo demonstrates the GuildOnly middleware done in (*Bot).Setup().
func (bot *Bot) GuildInfo(m *gateway.MessageCreateEvent) (string, error) {
g, err := bot.Ctx.GuildWithCount(m.GuildID)
if err != nil {
return "", fmt.Errorf("Failed to get guild: %v", err)
return "", fmt.Errorf("failed to get guild: %v", err)
}
return fmt.Sprintf(
"Your guild is %s, and its maximum members is %d",
g.Name, g.MaxMembers,
g.Name, g.ApproximateMembers,
), nil
}
// Repeat tells the bot to wait for the user's response, then repeat what they
// said.
func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) {
_, err := bot.Ctx.SendMessage(m.ChannelID,
"What do you want me to say?", nil)
_, err := bot.Ctx.SendMessage(m.ChannelID, "What do you want me to say?", nil)
if err != nil {
return "", err
}
@ -77,6 +72,8 @@ func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
// This might miss events that are sent immediately after. To make sure all
// events are caught, ChanFor should be used.
v := bot.Ctx.WaitFor(ctx, func(v interface{}) bool {
// Incoming event is a message create event:
mg, ok := v.(*gateway.MessageCreateEvent)
@ -89,7 +86,7 @@ func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) {
})
if v == nil {
return "", errors.New("Timed out waiting for response.")
return "", errors.New("timed out waiting for response")
}
ev := v.(*gateway.MessageCreateEvent)
@ -98,9 +95,7 @@ func (bot *Bot) Repeat(m *gateway.MessageCreateEvent) (string, error) {
// Embed is a simple embed creator. Its purpose is to demonstrate the usage of
// the ParseContent interface, as well as using the stdlib flag package.
func (bot *Bot) Embed(
m *gateway.MessageCreateEvent, f *arguments.Flag) (*discord.Embed, error) {
func (bot *Bot) Embed(_ *gateway.MessageCreateEvent, f arguments.Flag) (*discord.Embed, error) {
fs := arguments.NewFlagSet()
var (
@ -115,12 +110,12 @@ func (bot *Bot) Embed(
}
if len(fs.Args()) < 1 {
return nil, fmt.Errorf("Usage: embed [flags] content...\n" + fs.Usage())
return nil, fmt.Errorf("usage: embed [flags] content...\n" + fs.Usage())
}
// Check if the color string is valid.
if !strings.HasPrefix(*color, "#") || len(*color) != 7 {
return nil, errors.New("Invalid color, format must be #hhhhhh")
return nil, errors.New("invalid color, format must be #hhhhhh")
}
// Parse the color into decimal numbers.

View file

@ -7,6 +7,7 @@ import (
"strings"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/middlewares"
"github.com/diamondburned/arikawa/gateway"
)
@ -25,18 +26,16 @@ func (d *Debug) Setup(sub *bot.Subcommand) {
// Manually set the usage for each function.
sub.ChangeCommandInfo("GOOS", "",
"Prints the current operating system")
sub.ChangeCommandInfo("GOOS", "GOOS", "Prints the current operating system")
sub.ChangeCommandInfo("GC", "GC", "Triggers the garbage collector")
sub.ChangeCommandInfo("Goroutines", "", "Prints the current number of Goroutines")
sub.ChangeCommandInfo("GC", "",
"Triggers the garbage collecto")
sub.ChangeCommandInfo("Goroutines", "",
"Prints the current number of Goroutines")
sub.Hide("Die")
sub.AddMiddleware("Die", middlewares.AdminOnly(d.Context))
}
// ~go goroutines
func (d *Debug) Goroutines(m *gateway.MessageCreateEvent) (string, error) {
func (d *Debug) Goroutines(*gateway.MessageCreateEvent) (string, error) {
return fmt.Sprintf(
"goroutines: %d",
runtime.NumGoroutine(),
@ -44,19 +43,19 @@ func (d *Debug) Goroutines(m *gateway.MessageCreateEvent) (string, error) {
}
// ~go GOOS
func (d *Debug) RーGOOS(m *gateway.MessageCreateEvent) (string, error) {
func (d *Debug) GOOS(*gateway.MessageCreateEvent) (string, error) {
return strings.Title(runtime.GOOS), nil
}
// ~go GC
func (d *Debug) RーGC(m *gateway.MessageCreateEvent) (string, error) {
func (d *Debug) GC(*gateway.MessageCreateEvent) (string, error) {
runtime.GC()
return "Done.", nil
}
// ~go die
// This command will be hidden from ~help by default.
func (d *Debug) AーDie(m *gateway.MessageCreateEvent) error {
func (d *Debug) Die(m *gateway.MessageCreateEvent) error {
log.Fatalln("User", m.Author.Username, "killed the bot x_x")
return nil
}

View file

@ -19,6 +19,7 @@ func main() {
wait, err := bot.Start(token, commands, func(ctx *bot.Context) error {
ctx.HasPrefix = bot.NewPrefix("!", "~")
ctx.EditableCommands = true
// Subcommand demo, but this can be in another package.
ctx.MustRegisterSubcommand(&Debug{})

View file

@ -5,8 +5,8 @@ import (
"os"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
"github.com/diamondburned/arikawa/state"
"github.com/diamondburned/arikawa/utils/handler"
)
// To run, do `BOT_TOKEN="TOKEN HERE" go run .`

View file

@ -13,10 +13,10 @@ import (
var (
BaseEndpoint = "https://discord.com"
APIVersion = "6"
APIPath = "/api/v" + APIVersion
Version = "6"
Path = "/api/v" + Version
Endpoint = BaseEndpoint + APIPath + "/"
Endpoint = BaseEndpoint + Path + "/"
EndpointGateway = Endpoint + "gateway"
EndpointGatewayBot = EndpointGateway + "/bot"
)
@ -34,7 +34,7 @@ func NewClient(token string) *Client {
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
ses := Session{
Limiter: rate.NewLimiter(APIPath),
Limiter: rate.NewLimiter(Path),
Token: token,
UserAgent: UserAgent,
}

View file

@ -9,7 +9,7 @@ import (
var EndpointChannels = Endpoint + "channels/"
// Channels returns a list of guild channel objects.
func (c *Client) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
func (c *Client) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
var chs []discord.Channel
return chs, c.RequestJSON(&chs, "GET", EndpointGuilds+guildID.String()+"/channels")
}
@ -56,7 +56,7 @@ type CreateChannelData struct {
// CategoryID is the id of the parent category for a channel.
//
// Channel Types: Text, News, Store, Voice
CategoryID discord.Snowflake `json:"parent_id,string,omitempty"`
CategoryID discord.ChannelID `json:"parent_id,string,omitempty"`
// NSFW specifies whether the channel is nsfw.
//
// Channel Types: Text, News, Store.
@ -68,7 +68,7 @@ type CreateChannelData struct {
// Requires the MANAGE_CHANNELS permission.
// Fires a Channel Create Gateway event.
func (c *Client) CreateChannel(
guildID discord.Snowflake, data CreateChannelData) (*discord.Channel, error) {
guildID discord.GuildID, data CreateChannelData) (*discord.Channel, error) {
var ch *discord.Channel
return ch, c.RequestJSON(
&ch, "POST",
@ -79,7 +79,7 @@ func (c *Client) CreateChannel(
type MoveChannelData struct {
// ID is the channel id.
ID discord.Snowflake `json:"id"`
ID discord.ChannelID `json:"id"`
// Position is the sorting position of the channel
Position option.Int `json:"position"`
}
@ -87,15 +87,15 @@ type MoveChannelData struct {
// MoveChannel modifies the position of channels in the guild.
//
// Requires MANAGE_CHANNELS.
func (c *Client) MoveChannel(guildID discord.Snowflake, datum []MoveChannelData) error {
func (c *Client) MoveChannel(guildID discord.GuildID, data []MoveChannelData) error {
return c.FastRequest(
"PATCH",
EndpointGuilds+guildID.String()+"/channels", httputil.WithJSONBody(datum),
EndpointGuilds+guildID.String()+"/channels", httputil.WithJSONBody(data),
)
}
// Channel gets a channel by ID. Returns a channel object.
func (c *Client) Channel(channelID discord.Snowflake) (*discord.Channel, error) {
func (c *Client) Channel(channelID discord.ChannelID) (*discord.Channel, error) {
var channel *discord.Channel
return channel, c.RequestJSON(&channel, "GET", EndpointChannels+channelID.String())
}
@ -147,13 +147,13 @@ type ModifyChannelData struct {
Permissions *[]discord.Overwrite `json:"permission_overwrites,omitempty"`
// CategoryID is the id of the new parent category for a channel.
// Channel Types: Text, News, Store, Voice
CategoryID discord.Snowflake `json:"parent_id,string,omitempty"`
CategoryID discord.ChannelID `json:"parent_id,string,omitempty"`
}
// ModifyChannel updates a channel's settings.
//
// Requires the MANAGE_CHANNELS permission for the guild.
func (c *Client) ModifyChannel(channelID discord.Snowflake, data ModifyChannelData) error {
func (c *Client) ModifyChannel(channelID discord.ChannelID, data ModifyChannelData) error {
return c.FastRequest("PATCH", EndpointChannels+channelID.String(), httputil.WithJSONBody(data))
}
@ -163,28 +163,38 @@ func (c *Client) ModifyChannel(channelID discord.Snowflake, data ModifyChannelDa
// Channel Update Gateway event will fire for each of them.
//
// Fires a Channel Delete Gateway event.
func (c *Client) DeleteChannel(channelID discord.Snowflake) error {
func (c *Client) DeleteChannel(channelID discord.ChannelID) error {
return c.FastRequest("DELETE", EndpointChannels+channelID.String())
}
// https://discord.com/developers/docs/resources/channel#edit-channel-permissions-json-params
type EditChannelPermissionData struct {
// Type is either "role" or "member".
Type discord.OverwriteType `json:"type"`
// Allow is a permission bit set for granted permissions.
Allow discord.Permissions `json:"allow,string"`
// Deny is a permission bit set for denied permissions.
Deny discord.Permissions `json:"deny,string"`
}
// EditChannelPermission edits the channel's permission overwrites for a user
// or role in a channel. Only usable for guild channels.
//
// Requires the MANAGE_ROLES permission.
func (c *Client) EditChannelPermission(
channelID discord.Snowflake, overwrite discord.Overwrite) error {
channelID discord.ChannelID, overwriteID discord.Snowflake, data EditChannelPermissionData) error {
url := EndpointChannels + channelID.String() + "/permissions/" + overwrite.ID.String()
overwrite.ID = 0
return c.FastRequest("PUT", url, httputil.WithJSONBody(overwrite))
return c.FastRequest(
"PUT", EndpointChannels+channelID.String()+"/permissions/"+overwriteID.String(),
httputil.WithJSONBody(data),
)
}
// DeleteChannelPermission deletes a channel permission overwrite for a user or
// role in a channel. Only usable for guild channels.
//
// Requires the MANAGE_ROLES permission.
func (c *Client) DeleteChannelPermission(channelID, overwriteID discord.Snowflake) error {
func (c *Client) DeleteChannelPermission(channelID discord.ChannelID, overwriteID discord.Snowflake) error {
return c.FastRequest(
"DELETE",
EndpointChannels+channelID.String()+"/permissions/"+overwriteID.String(),
@ -193,13 +203,13 @@ func (c *Client) DeleteChannelPermission(channelID, overwriteID discord.Snowflak
// Typing posts a typing indicator to the channel. Undocumented, but the client
// usually clears the typing indicator after 8-10 seconds (or after a message).
func (c *Client) Typing(channelID discord.Snowflake) error {
func (c *Client) Typing(channelID discord.ChannelID) error {
return c.FastRequest("POST", EndpointChannels+channelID.String()+"/typing")
}
// PinnedMessages returns all pinned messages in the channel as an array of
// message objects.
func (c *Client) PinnedMessages(channelID discord.Snowflake) ([]discord.Message, error) {
func (c *Client) PinnedMessages(channelID discord.ChannelID) ([]discord.Message, error) {
var pinned []discord.Message
return pinned, c.RequestJSON(&pinned, "GET", EndpointChannels+channelID.String()+"/pins")
}
@ -207,22 +217,21 @@ func (c *Client) PinnedMessages(channelID discord.Snowflake) ([]discord.Message,
// PinMessage pins a message in a channel.
//
// Requires the MANAGE_MESSAGES permission.
func (c *Client) PinMessage(channelID, messageID discord.Snowflake) error {
func (c *Client) PinMessage(channelID discord.ChannelID, messageID discord.MessageID) error {
return c.FastRequest("PUT", EndpointChannels+channelID.String()+"/pins/"+messageID.String())
}
// UnpinMessage deletes a pinned message in a channel.
//
// Requires the MANAGE_MESSAGES permission.
func (c *Client) UnpinMessage(channelID, messageID discord.Snowflake) error {
func (c *Client) UnpinMessage(channelID discord.ChannelID, messageID discord.MessageID) error {
return c.FastRequest("DELETE", EndpointChannels+channelID.String()+"/pins/"+messageID.String())
}
// AddRecipient adds a user to a group direct message. As accessToken is needed,
// clearly this endpoint should only be used for OAuth. AccessToken can be
// obtained with the "gdm.join" scope.
func (c *Client) AddRecipient(
channelID, userID discord.Snowflake, accessToken, nickname string) error {
func (c *Client) AddRecipient(channelID discord.ChannelID, userID discord.UserID, accessToken, nickname string) error {
var params struct {
AccessToken string `json:"access_token"`
@ -240,7 +249,7 @@ func (c *Client) AddRecipient(
}
// RemoveRecipient removes a user from a group direct message.
func (c *Client) RemoveRecipient(channelID, userID discord.Snowflake) error {
func (c *Client) RemoveRecipient(channelID discord.ChannelID, userID discord.UserID) error {
return c.FastRequest(
"DELETE",
EndpointChannels+channelID.String()+"/recipients/"+userID.String(),
@ -255,7 +264,7 @@ type Ack struct {
// Ack marks the read state of a channel. This is undocumented. The method will
// write to the ack variable passed in. If this method is called asynchronously,
// then ack should be mutex guarded.
func (c *Client) Ack(channelID, messageID discord.Snowflake, ack *Ack) error {
func (c *Client) Ack(channelID discord.ChannelID, messageID discord.MessageID, ack *Ack) error {
return c.RequestJSON(
ack, "POST",
EndpointChannels+channelID.String()+"/messages/"+messageID.String()+"/ack",

View file

@ -5,24 +5,25 @@ import (
"github.com/diamondburned/arikawa/utils/httputil"
)
// Emoji is the API format of a regular Emoji, both Unicode or custom.
// Emoji is the API format of a regular Emoji, both Unicode or custom. This
// could usually be formatted by calling (discord.Emoji).APIString().
type Emoji = string
// NewCustomEmoji creates a new Emoji using a custom guild emoji as
// base.
// Unicode emojis should be directly passed to the function using Emoji.
func NewCustomEmoji(id discord.Snowflake, name string) Emoji {
func NewCustomEmoji(id discord.EmojiID, name string) Emoji {
return name + ":" + id.String()
}
// Emojis returns a list of emoji objects for the given guild.
func (c *Client) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
var emjs []discord.Emoji
return emjs, c.RequestJSON(&emjs, "GET", EndpointGuilds+guildID.String()+"/emojis")
func (c *Client) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
var e []discord.Emoji
return e, c.RequestJSON(&e, "GET", EndpointGuilds+guildID.String()+"/emojis")
}
// Emoji returns an emoji object for the given guild and emoji IDs.
func (c *Client) Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error) {
func (c *Client) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
var emj *discord.Emoji
return emj, c.RequestJSON(&emj, "GET",
EndpointGuilds+guildID.String()+"/emojis/"+emojiID.String())
@ -35,7 +36,7 @@ type CreateEmojiData struct {
// Image is the the 128x128 emoji image.
Image Image `json:"image"`
// Roles are the roles for which this emoji will be whitelisted.
Roles *[]discord.Snowflake `json:"roles,omitempty"`
Roles *[]discord.RoleID `json:"roles,omitempty"`
}
// CreateEmoji creates a new emoji in the guild. This endpoint requires
@ -43,8 +44,7 @@ type CreateEmojiData struct {
// "image/gif". However, ContentType can also be automatically detected
// (though shouldn't be relied on).
// Emojis and animated emojis have a maximum file size of 256kb.
func (c *Client) CreateEmoji(
guildID discord.Snowflake, data CreateEmojiData) (*discord.Emoji, error) {
func (c *Client) CreateEmoji(guildID discord.GuildID, data CreateEmojiData) (*discord.Emoji, error) {
// Max 256KB
if err := data.Image.Validate(256 * 1000); err != nil {
@ -64,14 +64,14 @@ type ModifyEmojiData struct {
// Name is the name of the emoji.
Name string `json:"name,omitempty"`
// Roles are the roles to which this emoji will be whitelisted.
Roles *[]discord.Snowflake `json:"roles,omitempty"`
Roles *[]discord.RoleID `json:"roles,omitempty"`
}
// ModifyEmoji changes an existing emoji. This requires MANAGE_EMOJIS. Name and
// roles are optional fields (though you'd want to change either though).
//
// Fires a Guild Emojis Update Gateway event.
func (c *Client) ModifyEmoji(guildID, emojiID discord.Snowflake, data ModifyEmojiData) error {
func (c *Client) ModifyEmoji(guildID discord.GuildID, emojiID discord.EmojiID, data ModifyEmojiData) error {
return c.FastRequest(
"PATCH",
EndpointGuilds+guildID.String()+"/emojis/"+emojiID.String(),
@ -83,6 +83,6 @@ func (c *Client) ModifyEmoji(guildID, emojiID discord.Snowflake, data ModifyEmoj
//
// Requires the MANAGE_EMOJIS permission.
// Fires a Guild Emojis Update Gateway event.
func (c *Client) DeleteEmoji(guildID, emojiID discord.Snowflake) error {
func (c *Client) DeleteEmoji(guildID discord.GuildID, emojiID discord.EmojiID) error {
return c.FastRequest("DELETE", EndpointGuilds+guildID.String()+"/emojis/"+emojiID.String())
}

View file

@ -9,6 +9,10 @@ import (
"github.com/diamondburned/arikawa/utils/json/option"
)
// maxGuildFetchLimit is the limit of max guilds per request, as imposed by
// Discord.
const maxGuildFetchLimit = 100
var EndpointGuilds = Endpoint + "guilds/"
// https://discordapp.com/developers/docs/resources/guild#create-guild-json-params
@ -55,13 +59,13 @@ type CreateGuildData struct {
Channels []discord.Channel `json:"channels,omitempty"`
// AFKChannelID is the id for the afk channel.
AFKChannelID discord.Snowflake `json:"afk_channel_id,omitempty"`
AFKChannelID discord.ChannelID `json:"afk_channel_id,omitempty"`
// AFKTimeout is the afk timeout in seconds.
AFKTimeout option.Seconds `json:"afk_timeout,omitempty"`
// SystemChannelID is the id of the channel where guild notices such as
// welcome messages and boost events are posted.
SystemChannelID discord.Snowflake `json:"system_channel_id,omitempty"`
SystemChannelID discord.ChannelID `json:"system_channel_id,omitempty"`
}
// CreateGuild creates a new guild. Returns a guild object on success.
@ -75,15 +79,24 @@ func (c *Client) CreateGuild(data CreateGuildData) (*discord.Guild, error) {
// Guild returns the guild object for the given id.
// ApproximateMembers and ApproximatePresences will not be set.
func (c *Client) Guild(id discord.Snowflake) (*discord.Guild, error) {
func (c *Client) Guild(id discord.GuildID) (*discord.Guild, error) {
var g *discord.Guild
return g, c.RequestJSON(&g, "GET", EndpointGuilds+id.String())
}
// GuildPreview returns the guild preview object for the given id, even if the
// user is not in the guild.
//
// This endpoint is only for public guilds.
func (c *Client) GuildPreview(id discord.GuildID) (*discord.GuildPreview, error) {
var g *discord.GuildPreview
return g, c.RequestJSON(&g, "GET", EndpointGuilds+id.String()+"/preview")
}
// GuildWithCount returns the guild object for the given id.
// This will also set the ApproximateMembers and ApproximatePresences fields
// of the guild struct.
func (c *Client) GuildWithCount(id discord.Snowflake) (*discord.Guild, error) {
func (c *Client) GuildWithCount(id discord.GuildID) (*discord.Guild, error) {
var g *discord.Guild
return g, c.RequestJSON(
&g, "GET",
@ -94,71 +107,125 @@ func (c *Client) GuildWithCount(id discord.Snowflake) (*discord.Guild, error) {
)
}
// Guilds returns all guilds, automatically paginating. Be careful, as this
// method may abuse the API by requesting thousands or millions of guilds. For
// lower-level access, use GuildsRange. Guilds returned have some fields
// filled only (ID, Name, Icon, Owner, Permissions).
// Guilds returns a list of partial guild objects the current user is a member
// of. This method automatically paginates until it reaches the passed limit,
// or, if the limit is set to 0, has fetched all guilds the user has joined.
//
// Max can be 0, in which case the function will try and fetch all guilds.
func (c *Client) Guilds(max uint) ([]discord.Guild, error) {
var guilds []discord.Guild
var after discord.Snowflake = 0
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
//
// When fetching the guilds, those with the smallest ID will be fetched first.
//
// Also note that 100 is the maximum number of guilds a non-bot user can join.
// Therefore, pagination is not needed for integrations that need to get a list
// of the users' guilds.
//
// Requires the guilds OAuth2 scope.
func (c *Client) Guilds(limit uint) ([]discord.Guild, error) {
return c.GuildsAfter(0, limit)
}
const hardLimit int = 100
// GuildsBefore returns a list of partial guild objects the current user is a
// member of. This method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
// id smaller than before.
//
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
//
// Requires the guilds OAuth2 scope.
func (c *Client) GuildsBefore(before discord.GuildID, limit uint) ([]discord.Guild, error) {
guilds := make([]discord.Guild, 0, limit)
unlimited := max == 0
fetch := uint(maxGuildFetchLimit)
for fetch := uint(hardLimit); max > 0 || unlimited; fetch = uint(hardLimit) {
if max > 0 {
if fetch > max {
fetch = max
unlimited := limit == 0
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
max -= fetch
limit -= fetch
}
g, err := c.GuildsAfter(after, fetch)
g, err := c.guildsRange(before, 0, fetch)
if err != nil {
return guilds, err
}
guilds = append(guilds, g...)
guilds = append(g, guilds...)
if len(g) < hardLimit {
if len(g) < maxGuildFetchLimit {
break
}
after = g[hardLimit-1].ID
before = g[0].ID
}
if len(guilds) == 0 {
return nil, nil
}
return guilds, nil
}
// GuildsBefore fetches guilds before the specified ID. Check GuildsRange.
func (c *Client) GuildsBefore(before discord.Snowflake, limit uint) ([]discord.Guild, error) {
return c.GuildsRange(before, 0, limit)
}
// GuildsAfter fetches guilds after the specified ID. Check GuildsRange.
func (c *Client) GuildsAfter(after discord.Snowflake, limit uint) ([]discord.Guild, error) {
return c.GuildsRange(0, after, limit)
}
// GuildsRange returns a list of partial guild objects the current user is a
// member of. Requires the guilds OAuth2 scope.
// GuildsAfter returns a list of partial guild objects the current user is a
// member of. This method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all guilds with an
// id higher than after.
//
// This endpoint returns 100 guilds by default, which is the maximum number
// of guilds a non-bot user can join. Therefore, pagination is not needed
// for integrations that need to get a list of the users' guilds.
func (c *Client) GuildsRange(before, after discord.Snowflake, limit uint) ([]discord.Guild, error) {
switch {
case limit == 0:
limit = 100
case limit > 100:
limit = 100
// As the underlying endpoint has a maximum of 100 guilds per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
//
// Requires the guilds OAuth2 scope.
func (c *Client) GuildsAfter(after discord.GuildID, limit uint) ([]discord.Guild, error) {
guilds := make([]discord.Guild, 0, limit)
fetch := uint(maxGuildFetchLimit)
unlimited := limit == 0
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
}
limit -= fetch
}
g, err := c.guildsRange(0, after, fetch)
if err != nil {
return guilds, err
}
guilds = append(guilds, g...)
if len(g) < maxGuildFetchLimit {
break
}
after = g[len(g)-1].ID
}
if len(guilds) == 0 {
return nil, nil
}
return guilds, nil
}
func (c *Client) guildsRange(
before, after discord.GuildID, limit uint) ([]discord.Guild, error) {
var param struct {
Before discord.Snowflake `schema:"before,omitempty"`
After discord.Snowflake `schema:"after,omitempty"`
Before discord.GuildID `schema:"before,omitempty"`
After discord.GuildID `schema:"after,omitempty"`
Limit uint `schema:"limit"`
}
@ -176,7 +243,7 @@ func (c *Client) GuildsRange(before, after discord.Snowflake, limit uint) ([]dis
}
// LeaveGuild leaves a guild.
func (c *Client) LeaveGuild(id discord.Snowflake) error {
func (c *Client) LeaveGuild(id discord.GuildID) error {
return c.FastRequest("DELETE", EndpointMe+"/guilds/"+id.String())
}
@ -203,7 +270,7 @@ type ModifyGuildData struct {
// AFKChannelID is the id for the afk channel.
//
// This field is nullable.
AFKChannelID discord.Snowflake `json:"afk_channel_id,string,omitempty"`
AFKChannelID discord.ChannelID `json:"afk_channel_id,string,omitempty"`
// AFKTimeout is the afk timeout in seconds.
AFKTimeout option.Seconds `json:"afk_timeout,omitempty"`
// Icon is the base64 1024x1024 png/jpeg/gif image for the guild icon
@ -217,23 +284,23 @@ type ModifyGuildData struct {
Banner *Image `json:"banner,omitempty"`
// OwnerID is the user id to transfer guild ownership to (must be owner).
OwnerID discord.Snowflake `json:"owner_id,omitempty"`
OwnerID discord.UserID `json:"owner_id,omitempty"`
// SystemChannelID is the id of the channel where guild notices such as
// welcome messages and boost events are posted.
//
// This field is nullable.
SystemChannelID discord.Snowflake `json:"system_channel_id,omitempty"`
SystemChannelID discord.ChannelID `json:"system_channel_id,omitempty"`
// RulesChannelID is the id of the channel where "PUBLIC" guilds display
// rules and/or guidelines.
//
// This field is nullable.
RulesChannelID discord.Snowflake `json:"rules_channel_id,omitempty"`
RulesChannelID discord.ChannelID `json:"rules_channel_id,omitempty"`
// PublicUpdatesChannelID is the id of the channel where admins and
// moderators of "PUBLIC" guilds receive notices from Discord.
//
// This field is nullable.
PublicUpdatesChannelID discord.Snowflake `json:"public_updates_channel_id,omitempty"`
PublicUpdatesChannelID discord.ChannelID `json:"public_updates_channel_id,omitempty"`
// PreferredLocale is the preferred locale of a "PUBLIC" guild used in
// server discovery and notices from Discord.
@ -244,7 +311,7 @@ type ModifyGuildData struct {
// ModifyGuild modifies a guild's settings. Requires the MANAGE_GUILD permission.
// Fires a Guild Update Gateway event.
func (c *Client) ModifyGuild(id discord.Snowflake, data ModifyGuildData) (*discord.Guild, error) {
func (c *Client) ModifyGuild(id discord.GuildID, data ModifyGuildData) (*discord.Guild, error) {
var g *discord.Guild
return g, c.RequestJSON(
&g, "PATCH",
@ -257,13 +324,13 @@ func (c *Client) ModifyGuild(id discord.Snowflake, data ModifyGuildData) (*disco
// DeleteGuild deletes a guild permanently. The User must be owner.
//
// Fires a Guild Delete Gateway event.
func (c *Client) DeleteGuild(id discord.Snowflake) error {
func (c *Client) DeleteGuild(id discord.GuildID) error {
return c.FastRequest("DELETE", EndpointGuilds+id.String())
}
// GuildVoiceRegions is the same as /voice, but returns VIP ones as well if
// available.
func (c *Client) VoiceRegionsGuild(guildID discord.Snowflake) ([]discord.VoiceRegion, error) {
func (c *Client) VoiceRegionsGuild(guildID discord.GuildID) ([]discord.VoiceRegion, error) {
var vrs []discord.VoiceRegion
return vrs, c.RequestJSON(&vrs, "GET", EndpointGuilds+guildID.String()+"/regions")
}
@ -271,11 +338,11 @@ func (c *Client) VoiceRegionsGuild(guildID discord.Snowflake) ([]discord.VoiceRe
// https://discord.com/developers/docs/resources/audit-log#get-guild-audit-log-query-string-parameters
type AuditLogData struct {
// UserID filters the log for actions made by a user.
UserID discord.Snowflake `schema:"user_id,omitempty"`
UserID discord.UserID `schema:"user_id,omitempty"`
// ActionType is the type of audit log event.
ActionType discord.AuditLogEvent `schema:"action_type,omitempty"`
// Before filters the log before a certain entry ID.
Before discord.Snowflake `schema:"before,omitempty"`
Before discord.AuditLogEntryID `schema:"before,omitempty"`
// Limit limits how many entries are returned (default 50, minimum 1,
// maximum 100).
Limit uint `schema:"limit"`
@ -284,7 +351,7 @@ type AuditLogData struct {
// AuditLog returns an audit log object for the guild.
//
// Requires the VIEW_AUDIT_LOG permission.
func (c *Client) AuditLog(guildID discord.Snowflake, data AuditLogData) (*discord.AuditLog, error) {
func (c *Client) AuditLog(guildID discord.GuildID, data AuditLogData) (*discord.AuditLog, error) {
switch {
case data.Limit == 0:
data.Limit = 50
@ -304,7 +371,7 @@ func (c *Client) AuditLog(guildID discord.Snowflake, data AuditLogData) (*discor
// Integrations returns a list of integration objects for the guild.
//
// Requires the MANAGE_GUILD permission.
func (c *Client) Integrations(guildID discord.Snowflake) ([]discord.Integration, error) {
func (c *Client) Integrations(guildID discord.GuildID) ([]discord.Integration, error) {
var ints []discord.Integration
return ints, c.RequestJSON(&ints, "GET", EndpointGuilds+guildID.String()+"/integrations")
}
@ -314,12 +381,13 @@ func (c *Client) Integrations(guildID discord.Snowflake) ([]discord.Integration,
//
// Requires the MANAGE_GUILD permission.
// Fires a Guild Integrations Update Gateway event.
func (c *Client) AttachIntegration(guildID,
integrationID discord.Snowflake, integrationType discord.Service) error {
func (c *Client) AttachIntegration(
guildID discord.GuildID, integrationID discord.IntegrationID,
integrationType discord.Service) error {
var param struct {
Type discord.Service `json:"type"`
ID discord.Snowflake `json:"id"`
Type discord.Service `json:"type"`
ID discord.IntegrationID `json:"id"`
}
param.Type = integrationType
@ -351,7 +419,7 @@ type ModifyIntegrationData struct {
// Requires the MANAGE_GUILD permission.
// Fires a Guild Integrations Update Gateway event.
func (c *Client) ModifyIntegration(
guildID, integrationID discord.Snowflake, data ModifyIntegrationData) error {
guildID discord.GuildID, integrationID discord.IntegrationID, data ModifyIntegrationData) error {
return c.FastRequest(
"PATCH",
EndpointGuilds+guildID.String()+"/integrations/"+integrationID.String(),
@ -360,35 +428,41 @@ func (c *Client) ModifyIntegration(
}
// Sync an integration. Requires the MANAGE_GUILD permission.
func (c *Client) SyncIntegration(guildID, integrationID discord.Snowflake) error {
func (c *Client) SyncIntegration(guildID discord.GuildID, integrationID discord.IntegrationID) error {
return c.FastRequest(
"POST",
EndpointGuilds+guildID.String()+"/integrations/"+integrationID.String()+"/sync",
)
}
// GuildEmbed returns the guild embed object.
// GuildWidget returns the guild widget object.
//
// Requires the MANAGE_GUILD permission.
func (c *Client) GuildEmbed(guildID discord.Snowflake) (*discord.GuildEmbed, error) {
var ge *discord.GuildEmbed
return ge, c.RequestJSON(&ge, "GET", EndpointGuilds+guildID.String()+"/embed")
func (c *Client) GuildWidget(guildID discord.GuildID) (*discord.GuildWidget, error) {
var ge *discord.GuildWidget
return ge, c.RequestJSON(&ge, "GET", EndpointGuilds+guildID.String()+"/widget")
}
// https://discord.com/developers/docs/resources/guild#guild-embed-object-guild-embed-structure
type ModifyGuildEmbedData struct {
Enabled option.Bool `json:"enabled,omitempty"`
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
type ModifyGuildWidgetData struct {
// Enabled specifies whether the widget is enabled.
Enabled option.Bool `json:"enabled,omitempty"`
// ChannelID is the widget channel id.
ChannelID discord.ChannelID `json:"channel_id,omitempty"`
}
// ModifyGuildEmbed modifies the guild embed and updates the passed in
// GuildEmbed data.
// ModifyGuildWidget modifies a guild widget object for the guild.
//
// This method should be used with care: if you still want the embed enabled,
// you need to set the Enabled boolean, even if it's already enabled. If you
// don't, JSON will default it to false.
func (c *Client) ModifyGuildEmbed(guildID discord.Snowflake, data discord.GuildEmbed) error {
return c.RequestJSON(&data, "PATCH", EndpointGuilds+guildID.String()+"/embed")
// Requires the MANAGE_GUILD permission.
func (c *Client) ModifyGuildWidget(
guildID discord.GuildID, data ModifyGuildWidgetData) (*discord.GuildWidget, error) {
var w *discord.GuildWidget
return w, c.RequestJSON(
&w, "PATCH",
EndpointGuilds+guildID.String()+"/widget",
httputil.WithJSONBody(data),
)
}
// GuildVanityURL returns *Invite for guilds that have that feature enabled,
@ -396,7 +470,7 @@ func (c *Client) ModifyGuildEmbed(guildID discord.Snowflake, data discord.GuildE
// guild is not set.
//
// Requires MANAGE_GUILD.
func (c *Client) GuildVanityURL(guildID discord.Snowflake) (*discord.Invite, error) {
func (c *Client) GuildVanityURL(guildID discord.GuildID) (*discord.Invite, error) {
var inv *discord.Invite
return inv, c.RequestJSON(&inv, "GET", EndpointGuilds+guildID.String()+"/vanity-url")
}
@ -436,13 +510,13 @@ const (
// GuildImageURL returns a link to the PNG image widget for the guild.
//
// Requires no permissions or authentication.
func (c *Client) GuildImageURL(guildID discord.Snowflake, img GuildImageStyle) string {
func (c *Client) GuildImageURL(guildID discord.GuildID, img GuildImageStyle) string {
return EndpointGuilds + guildID.String() + "/widget.png?style=" + string(img)
}
// GuildImage returns a PNG image widget for the guild. Requires no permissions
// or authentication.
func (c *Client) GuildImage(guildID discord.Snowflake, img GuildImageStyle) (io.ReadCloser, error) {
func (c *Client) GuildImage(guildID discord.GuildID, img GuildImageStyle) (io.ReadCloser, error) {
r, err := c.Request("GET", c.GuildImageURL(guildID, img))
if err != nil {
return nil, err

View file

@ -10,9 +10,8 @@ import (
"github.com/pkg/errors"
)
var ErrInvalidImageCT = errors.New("Unknown image content-type")
var ErrInvalidImageData = errors.New("Invalid image data")
var ErrNoImage = errors.New("null image")
var ErrInvalidImageCT = errors.New("unknown image content-type")
var ErrInvalidImageData = errors.New("invalid image data")
type ErrImageTooLarge struct {
Size, Max int

View file

@ -14,7 +14,7 @@ import (
type testConfig struct {
BotToken string
ChannelID discord.Snowflake
ChannelID discord.ChannelID
}
func mustConfig(t *testing.T) testConfig {
@ -35,7 +35,7 @@ func mustConfig(t *testing.T) testConfig {
return testConfig{
BotToken: token,
ChannelID: id,
ChannelID: discord.ChannelID(id),
}
}
@ -85,7 +85,7 @@ func TestReactions(t *testing.T) {
client := NewClient("Bot " + cfg.BotToken)
msg := fmt.Sprint("This is a message sent at ", time.Now())
msg := fmt.Sprintf("This is a message sent at %v.", time.Now())
// Send a new message.
m, err := client.SendMessage(cfg.ChannelID, msg, nil)
@ -93,9 +93,18 @@ func TestReactions(t *testing.T) {
t.Fatal("Failed to send message:", err)
}
now := time.Now()
for _, emojiString := range emojisToSend {
if err := client.React(cfg.ChannelID, m.ID, emojiString); err != nil {
t.Fatal("Failed to send emoji "+emojiString+":", err)
}
}
msg += fmt.Sprintf(" Total time taken to send all reactions: %v.", time.Now().Sub(now))
m, err = client.EditMessage(cfg.ChannelID, m.ID, msg, nil, false)
if err != nil {
t.Fatal("Failed to edit message:", err)
}
}

View file

@ -39,7 +39,7 @@ func (c *Client) InviteWithCounts(code string) (*discord.Invite, error) {
// the channel. Only usable for guild channels.
//
// Requires the MANAGE_CHANNELS permission.
func (c *Client) ChannelInvites(channelID discord.Snowflake) ([]discord.Invite, error) {
func (c *Client) ChannelInvites(channelID discord.ChannelID) ([]discord.Invite, error) {
var invs []discord.Invite
return invs, c.RequestJSON(&invs, "GET",
EndpointChannels+channelID.String()+"/invites")
@ -49,7 +49,7 @@ func (c *Client) ChannelInvites(channelID discord.Snowflake) ([]discord.Invite,
// guild.
//
// Requires the MANAGE_GUILD permission.
func (c *Client) GuildInvites(guildID discord.Snowflake) ([]discord.Invite, error) {
func (c *Client) GuildInvites(guildID discord.GuildID) ([]discord.Invite, error) {
var invs []discord.Invite
return invs, c.RequestJSON(&invs, "GET",
EndpointGuilds+guildID.String()+"/invites")
@ -82,7 +82,7 @@ type CreateInviteData struct {
//
// Requires the CREATE_INSTANT_INVITE permission.
func (c *Client) CreateInvite(
channelID discord.Snowflake, data CreateInviteData) (*discord.Invite, error) {
channelID discord.ChannelID, data CreateInviteData) (*discord.Invite, error) {
var inv *discord.Invite
return inv, c.RequestJSON(
&inv, "POST",
@ -91,10 +91,11 @@ func (c *Client) CreateInvite(
)
}
// DeleteInvite deletes a channel permission overwrite for a user or role in a
// channel. Only usable for guild channels.
// DeleteInvite deletes an invite.
//
// Requires the MANAGE_ROLES permission.
// Requires the MANAGE_CHANNELS permission on the channel this invite belongs
// to, or MANAGE_GUILD to remove any invite across the guild.
// Fires a Invite Delete Gateway event.
func (c *Client) DeleteInvite(code string) (*discord.Invite, error) {
var inv *discord.Invite
return inv, c.RequestJSON(&inv, "DELETE", EndpointInvites+code)

View file

@ -6,53 +6,76 @@ import (
"github.com/diamondburned/arikawa/utils/json/option"
)
// Member returns a guild member object for the specified user..
func (c *Client) Member(guildID, userID discord.Snowflake) (*discord.Member, error) {
const maxMemberFetchLimit = 1000
// Member returns a guild member object for the specified user.
func (c *Client) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
var m *discord.Member
return m, c.RequestJSON(&m, "GET", EndpointGuilds+guildID.String()+"/members/"+userID.String())
}
// Members returns members until it reaches max. This function automatically
// paginates, meaning the normal 1000 limit is handled internally.
// Members returns a list of members of the guild with the passed id. This
// method automatically paginates until it reaches the passed limit, or, if the
// limit is set to 0, has fetched all members in the guild.
//
// Max can be 0, in which case the function will try and fetch all members.
func (c *Client) Members(guildID discord.Snowflake, max uint) ([]discord.Member, error) {
var mems []discord.Member
var after discord.Snowflake = 0
// As the underlying endpoint has a maximum of 1000 members per request, at
// maximum a total of limit/1000 rounded up requests will be made, although
// they may be less if no more members are available.
//
// When fetching the members, those with the smallest ID will be fetched first.
func (c *Client) Members(guildID discord.GuildID, limit uint) ([]discord.Member, error) {
return c.MembersAfter(guildID, 0, limit)
}
const hardLimit int = 1000
// MembersAfter returns a list of members of the guild with the passed id. This
// method automatically paginates until it reaches the passed limit, or, if the
// limit is set to 0, has fetched all members with an id higher than after.
//
// As the underlying endpoint has a maximum of 1000 members per request, at
// maximum a total of limit/1000 rounded up requests will be made, although
// they may be less, if no more members are available.
func (c *Client) MembersAfter(
guildID discord.GuildID, after discord.UserID, limit uint) ([]discord.Member, error) {
unlimited := max == 0
mems := make([]discord.Member, 0, limit)
for fetch := uint(hardLimit); max > 0 || unlimited; fetch = uint(hardLimit) {
if max > 0 {
if fetch > max {
fetch = max
fetch := uint(maxMemberFetchLimit)
unlimited := limit == 0
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
}
max -= fetch
limit -= fetch
}
m, err := c.MembersAfter(guildID, after, fetch)
m, err := c.membersAfter(guildID, after, fetch)
if err != nil {
return mems, err
}
mems = append(mems, m...)
// There aren't any to fetch, even if this is less than max.
if len(mems) < hardLimit {
// There aren't any to fetch, even if this is less than limit.
if len(m) < maxMemberFetchLimit {
break
}
after = mems[hardLimit-1].User.ID
after = mems[len(mems)-1].User.ID
}
if len(mems) == 0 {
return nil, nil
}
return mems, nil
}
// MembersAfter returns a list of all guild members, from 1-1000 for limits. The
// default limit is 1 and the maximum limit is 1000.
func (c *Client) MembersAfter(
guildID, after discord.Snowflake, limit uint) ([]discord.Member, error) {
func (c *Client) membersAfter(
guildID discord.GuildID, after discord.UserID, limit uint) ([]discord.Member, error) {
switch {
case limit == 0:
@ -62,8 +85,8 @@ func (c *Client) MembersAfter(
}
var param struct {
After discord.Snowflake `schema:"after,omitempty"`
Limit uint `schema:"limit"`
After discord.UserID `schema:"after,omitempty"`
Limit uint `schema:"limit"`
}
param.Limit = limit
@ -89,7 +112,7 @@ type AddMemberData struct {
// Roles is an array of role ids the member is assigned.
//
// Requires MANAGE_ROLES.
Roles *[]discord.Snowflake `json:"roles,omitempty"`
Roles *[]discord.RoleID `json:"roles,omitempty"`
// Mute specifies whether the user is muted in voice channels.
//
// Requires MUTE_MEMBERS.
@ -111,7 +134,7 @@ type AddMemberData struct {
// application used for authorization), and the bot must be a member of the
// guild with CREATE_INSTANT_INVITE permission.
func (c *Client) AddMember(
guildID, userID discord.Snowflake, data AddMemberData) (*discord.Member, error) {
guildID discord.GuildID, userID discord.UserID, data AddMemberData) (*discord.Member, error) {
var mem *discord.Member
return mem, c.RequestJSON(
&mem, "PUT",
@ -129,7 +152,7 @@ type ModifyMemberData struct {
// Roles is an array of role ids the member is assigned.
//
// Requires MANAGE_ROLES.
Roles *[]discord.Snowflake `json:"roles,omitempty"`
Roles *[]discord.RoleID `json:"roles,omitempty"`
// Mute specifies whether the user is muted in voice channels.
//
// Requires MUTE_MEMBERS.
@ -143,14 +166,14 @@ type ModifyMemberData struct {
// connected to voice).
//
// Requires MOVE_MEMBER
VoiceChannel discord.Snowflake `json:"channel_id,omitempty"`
VoiceChannel discord.ChannelID `json:"channel_id,omitempty"`
}
// ModifyMember modifies attributes of a guild member. If the channel_id is set
// to null, this will force the target user to be disconnected from voice.
//
// Fires a Guild Member Update Gateway event.
func (c *Client) ModifyMember(guildID, userID discord.Snowflake, data ModifyMemberData) error {
func (c *Client) ModifyMember(guildID discord.GuildID, userID discord.UserID, data ModifyMemberData) error {
return c.FastRequest(
"PATCH",
@ -159,21 +182,28 @@ func (c *Client) ModifyMember(guildID, userID discord.Snowflake, data ModifyMemb
)
}
// https://discord.com/developers/docs/resources/guild#get-guild-prune-count-query-string-params
type PruneCountData struct {
// Days is the number of days to count prune for (1 or more, default 7).
Days uint `schema:"days"`
// IncludedRoles are the role(s) to include.
IncludedRoles []discord.RoleID `schema:"include_roles,omitempty"`
}
// PruneCount returns the number of members that would be removed in a prune
// operation. Days must be 1 or more, default 7.
//
// By default, prune will not remove users with roles. You can optionally
// include specific roles in your prune by providing the IncludedRoles
// parameter. Any inactive user that has a subset of the provided role(s)
// will be counted in the prune and users with additional roles will not.
//
// Requires KICK_MEMBERS.
func (c *Client) PruneCount(guildID discord.Snowflake, days uint) (uint, error) {
if days == 0 {
days = 7
func (c *Client) PruneCount(guildID discord.GuildID, data PruneCountData) (uint, error) {
if data.Days == 0 {
data.Days = 7
}
var param struct {
Days uint `schema:"days"`
}
param.Days = days
var resp struct {
Pruned uint `json:"pruned"`
}
@ -181,50 +211,35 @@ func (c *Client) PruneCount(guildID discord.Snowflake, days uint) (uint, error)
return resp.Pruned, c.RequestJSON(
&resp, "GET",
EndpointGuilds+guildID.String()+"/prune",
httputil.WithSchema(c, param),
httputil.WithSchema(c, data),
)
}
// https://discord.com/developers/docs/resources/guild#begin-guild-prune-query-string-params
type PruneData struct {
// Days is the number of days to prune (1 or more, default 7).
Days uint `schema:"days"`
// ReturnCount specifies whether 'pruned' is returned. Discouraged for
// large guilds.
ReturnCount bool `schema:"compute_prune_count"`
// IncludedRoles are the role(s) to include.
IncludedRoles []discord.RoleID `schema:"include_roles,omitempty"`
}
// Prune begins a prune. Days must be 1 or more, default 7.
//
// Requires KICK_MEMBERS.
func (c *Client) Prune(guildID discord.Snowflake, days uint) error {
if days == 0 {
days = 7
}
var param struct {
Days uint `schema:"days"`
RetCount bool `schema:"compute_prune_count"`
}
param.Days = days
param.RetCount = false
return c.FastRequest(
"POST",
EndpointGuilds+guildID.String()+"/prune",
httputil.WithSchema(c, param),
)
}
// PruneWithCounts returns the number of members that is removed. Days must be 1 or more,
// default 7.
// By default, prune will not remove users with roles. You can optionally
// include specific roles in your prune by providing the IncludedRoles
// parameter. Any inactive user that has a subset of the provided role(s)
// will be included in the prune and users with additional roles will not.
//
// Requires KICK_MEMBERS.
func (c *Client) PruneWithCount(guildID discord.Snowflake, days uint) (uint, error) {
if days == 0 {
days = 7
// Fires multiple Guild Member Remove Gateway events.
func (c *Client) Prune(guildID discord.GuildID, data PruneData) (uint, error) {
if data.Days == 0 {
data.Days = 7
}
var param struct {
Days uint `schema:"days"`
RetCount bool `schema:"compute_prune_count"`
}
param.Days = days
param.RetCount = true
var resp struct {
Pruned uint `json:"pruned"`
}
@ -232,7 +247,7 @@ func (c *Client) PruneWithCount(guildID discord.Snowflake, days uint) (uint, err
return resp.Pruned, c.RequestJSON(
&resp, "POST",
EndpointGuilds+guildID.String()+"/prune",
httputil.WithSchema(c, param),
httputil.WithSchema(c, data),
)
}
@ -240,17 +255,35 @@ func (c *Client) PruneWithCount(guildID discord.Snowflake, days uint) (uint, err
//
// Requires KICK_MEMBERS permission.
// Fires a Guild Member Remove Gateway event.
func (c *Client) Kick(guildID, userID discord.Snowflake) error {
func (c *Client) Kick(guildID discord.GuildID, userID discord.UserID) error {
return c.KickWithReason(guildID, userID, "")
}
// KickWithReason removes a member from a guild.
// The reason, if non-empty, will be displayed in the audit log of the guild.
//
// Requires KICK_MEMBERS permission.
// Fires a Guild Member Remove Gateway event.
func (c *Client) KickWithReason(
guildID discord.GuildID, userID discord.UserID, reason string) error {
var data struct {
Reason string `schema:"reason,omitempty"`
}
data.Reason = reason
return c.FastRequest(
"DELETE",
EndpointGuilds+guildID.String()+"/members/"+userID.String(),
httputil.WithSchema(c, data),
)
}
// Bans returns a list of ban objects for the users banned from this guild.
//
// Requires the BAN_MEMBERS permission.
func (c *Client) Bans(guildID discord.Snowflake) ([]discord.Ban, error) {
func (c *Client) Bans(guildID discord.GuildID) ([]discord.Ban, error) {
var bans []discord.Ban
return bans, c.RequestJSON(
&bans, "GET",
@ -261,7 +294,7 @@ func (c *Client) Bans(guildID discord.Snowflake) ([]discord.Ban, error) {
// GetBan returns a ban object for the given user.
//
// Requires the BAN_MEMBERS permission.
func (c *Client) GetBan(guildID, userID discord.Snowflake) (*discord.Ban, error) {
func (c *Client) GetBan(guildID discord.GuildID, userID discord.UserID) (*discord.Ban, error) {
var ban *discord.Ban
return ban, c.RequestJSON(
&ban, "GET",
@ -281,11 +314,7 @@ type BanData struct {
// banned user.
//
// Requires the BAN_MEMBERS permission.
func (c *Client) Ban(guildID, userID discord.Snowflake, data BanData) error {
if *data.DeleteDays > 7 {
*data.DeleteDays = 7
}
func (c *Client) Ban(guildID discord.GuildID, userID discord.UserID, data BanData) error {
return c.FastRequest(
"PUT",
EndpointGuilds+guildID.String()+"/bans/"+userID.String(),
@ -297,6 +326,6 @@ func (c *Client) Ban(guildID, userID discord.Snowflake, data BanData) error {
//
// Requires the BAN_MEMBERS permissions.
// Fires a Guild Ban Remove Gateway event.
func (c *Client) Unban(guildID, userID discord.Snowflake) error {
func (c *Client) Unban(guildID discord.GuildID, userID discord.UserID) error {
return c.FastRequest("DELETE", EndpointGuilds+guildID.String()+"/bans/"+userID.String())
}

View file

@ -1,72 +1,155 @@
package api
import (
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/diamondburned/arikawa/utils/json/option"
)
// Messages gets all messages, automatically paginating. Use with care, as
// this could get as many as hundred thousands of messages, making a lot of
// queries.
// the limit of max messages per request, as imposed by Discord
const maxMessageFetchLimit = 100
// Messages returns a slice filled with the most recent messages sent in the
// channel with the passed ID. The method automatically paginates until it
// reaches the passed limit, or, if the limit is set to 0, has fetched all
// messages in the channel.
//
// Max can be 0, in which case the function will try and fetch all messages.
func (c *Client) Messages(channelID discord.Snowflake, max uint) ([]discord.Message, error) {
var msgs []discord.Message
var after discord.Snowflake = 0
// As the underlying endpoint is capped at a maximum of 100 messages per
// request, at maximum a total of limit/100 rounded up requests will be made,
// although they may be less, if no more messages are available.
//
// When fetching the messages, those with the highest ID, will be fetched
// first.
// The returned slice will be sorted from latest to oldest.
func (c *Client) Messages(channelID discord.ChannelID, limit uint) ([]discord.Message, error) {
// Since before is 0 it will be omitted by the http lib, which in turn
// will lead discord to send us the most recent messages without having to
// specify a Snowflake.
return c.MessagesBefore(channelID, 0, limit)
}
const hardLimit int = 100
// MessagesAround returns messages around the ID, with a limit of 100.
func (c *Client) MessagesAround(
channelID discord.ChannelID, around discord.MessageID, limit uint) ([]discord.Message, error) {
unlimited := max == 0
return c.messagesRange(channelID, 0, 0, around, limit)
}
for fetch := uint(hardLimit); max > 0 || unlimited; fetch = uint(hardLimit) {
if max > 0 {
if fetch > max {
fetch = max
// MessagesBefore returns a slice filled with the messages sent in the channel
// with the passed id. The method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all messages in the
// channel with an id smaller than before.
//
// As the underlying endpoint has a maximum of 100 messages per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more messages are available.
//
// The returned slice will be sorted from latest to oldest.
func (c *Client) MessagesBefore(
channelID discord.ChannelID, before discord.MessageID, limit uint) ([]discord.Message, error) {
msgs := make([]discord.Message, 0, limit)
fetch := uint(maxMessageFetchLimit)
// Check if we are truly fetching unlimited messages to avoid confusion
// later on, if the limit reaches 0.
unlimited := limit == 0
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
max -= fetch
limit -= maxMessageFetchLimit
}
m, err := c.messagesRange(channelID, before, 0, 0, fetch)
if err != nil {
return msgs, err
}
// Append the older messages into the list of newer messages.
msgs = append(msgs, m...)
if len(m) < maxMessageFetchLimit {
break
}
before = m[len(m)-1].ID
}
if len(msgs) == 0 {
return nil, nil
}
return msgs, nil
}
// MessagesAfter returns a slice filled with the messages sent in the channel
// with the passed ID. The method automatically paginates until it reaches the
// passed limit, or, if the limit is set to 0, has fetched all messages in the
// channel with an id higher than after.
//
// As the underlying endpoint has a maximum of 100 messages per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more messages are available.
//
// The returned slice will be sorted from latest to oldest.
func (c *Client) MessagesAfter(
channelID discord.ChannelID, after discord.MessageID, limit uint) ([]discord.Message, error) {
// 0 is uint's zero value and will lead to the after param getting omitted,
// which in turn will lead to the most recent messages being returned.
// Setting this to 1 will prevent that.
if after == 0 {
after = 1
}
var msgs []discord.Message
fetch := uint(maxMessageFetchLimit)
// Check if we are truly fetching unlimited messages to avoid confusion
// later on, if the limit reaches 0.
unlimited := limit == 0
for limit > 0 || unlimited {
if limit > 0 {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if fetch > limit {
fetch = limit
}
limit -= maxMessageFetchLimit
}
m, err := c.messagesRange(channelID, 0, after, 0, fetch)
if err != nil {
return msgs, err
}
msgs = append(msgs, m...)
// Prepend the older messages into the newly-fetched messages list.
msgs = append(m, msgs...)
if len(m) < hardLimit {
if len(m) < maxMessageFetchLimit {
break
}
after = m[hardLimit-1].Author.ID
after = m[0].ID
}
if len(msgs) == 0 {
return nil, nil
}
return msgs, nil
}
// MessagesAround returns messages around the ID, with a limit of 1-100.
func (c *Client) MessagesAround(
channelID, around discord.Snowflake, limit uint) ([]discord.Message, error) {
return c.messagesRange(channelID, 0, 0, around, limit)
}
// MessagesBefore returns messages before the ID, with a limit of 1-100.
func (c *Client) MessagesBefore(
channelID, before discord.Snowflake, limit uint) ([]discord.Message, error) {
return c.messagesRange(channelID, before, 0, 0, limit)
}
// MessagesAfter returns messages after the ID, with a limit of 1-100.
func (c *Client) MessagesAfter(
channelID, after discord.Snowflake, limit uint) ([]discord.Message, error) {
return c.messagesRange(channelID, 0, after, 0, limit)
}
func (c *Client) messagesRange(
channelID, before, after, around discord.Snowflake,
limit uint) ([]discord.Message, error) {
channelID discord.ChannelID, before, after, around discord.MessageID, limit uint) ([]discord.Message, error) {
switch {
case limit == 0:
@ -76,9 +159,9 @@ func (c *Client) messagesRange(
}
var param struct {
Before discord.Snowflake `schema:"before,omitempty"`
After discord.Snowflake `schema:"after,omitempty"`
Around discord.Snowflake `schema:"around,omitempty"`
Before discord.MessageID `schema:"before,omitempty"`
After discord.MessageID `schema:"after,omitempty"`
Around discord.MessageID `schema:"around,omitempty"`
Limit uint `schema:"limit"`
}
@ -100,7 +183,7 @@ func (c *Client) messagesRange(
//
// If operating on a guild channel, this endpoint requires the
// READ_MESSAGE_HISTORY permission to be present on the current user.
func (c *Client) Message(channelID, messageID discord.Snowflake) (*discord.Message, error) {
func (c *Client) Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
var msg *discord.Message
return msg, c.RequestJSON(&msg, "GET",
EndpointChannels+channelID.String()+"/messages/"+messageID.String())
@ -112,7 +195,7 @@ func (c *Client) Message(channelID, messageID discord.Snowflake) (*discord.Messa
// permission to be present on the current user.
//
// Fires a Message Create Gateway event.
func (c *Client) SendText(channelID discord.Snowflake, content string) (*discord.Message, error) {
func (c *Client) SendText(channelID discord.ChannelID, content string) (*discord.Message, error) {
return c.SendMessageComplex(channelID, SendMessageData{
Content: content,
})
@ -125,7 +208,7 @@ func (c *Client) SendText(channelID discord.Snowflake, content string) (*discord
//
// Fires a Message Create Gateway event.
func (c *Client) SendEmbed(
channelID discord.Snowflake, e discord.Embed) (*discord.Message, error) {
channelID discord.ChannelID, e discord.Embed) (*discord.Message, error) {
return c.SendMessageComplex(channelID, SendMessageData{
Embed: &e,
@ -139,7 +222,7 @@ func (c *Client) SendEmbed(
//
// Fires a Message Create Gateway event.
func (c *Client) SendMessage(
channelID discord.Snowflake, content string, embed *discord.Embed) (*discord.Message, error) {
channelID discord.ChannelID, content string, embed *discord.Embed) (*discord.Message, error) {
return c.SendMessageComplex(channelID, SendMessageData{
Content: content,
@ -152,7 +235,8 @@ type EditMessageData struct {
// Content is the new message contents (up to 2000 characters).
Content option.NullableString `json:"content,omitempty"`
// Embed contains embedded rich content.
Embed *discord.Embed `json:"embed,omitempty"`
Embed *discord.Embed `json:"embed,omitempty"`
// AllowedMentions are the allowed mentions for a message.
AllowedMentions *AllowedMentions `json:"allowed_mentions,omitempty"`
// Flags edits the flags of a message (only SUPPRESS_EMBEDS can currently
// be set/unset)
@ -161,17 +245,67 @@ type EditMessageData struct {
Flags *discord.MessageFlags `json:"flags,omitempty"`
}
// Edit a previously sent message. The fields Content, Embed,
// AllowedMentions and Flags can be edited by the original message author.
// Other users can only edit flags and only if they have the MANAGE_MESSAGES
// permission in the corresponding channel. When specifying flags, ensure to
// include all previously set flags/bits in addition to ones that you are
// modifying. Only flags documented in EditMessageData may be modified by users
// (unsupported flag changes are currently ignored without error).
// EditText edits the contents of a previously sent message. For more
// documentation, refer to EditMessageComplex.
func (c *Client) EditText(
channelID discord.ChannelID, messageID discord.MessageID, content string) (*discord.Message, error) {
return c.EditMessageComplex(channelID, messageID, EditMessageData{
Content: option.NewNullableString(content),
})
}
// EditEmbed edits the embed of a previously sent message. For more
// documentation, refer to EditMessageComplex.
func (c *Client) EditEmbed(
channelID discord.ChannelID, messageID discord.MessageID, embed discord.Embed) (*discord.Message, error) {
return c.EditMessageComplex(channelID, messageID, EditMessageData{
Embed: &embed,
})
}
// EditMessage edits a previously sent message. For more documentation, refer to
// EditMessageComplex.
func (c *Client) EditMessage(
channelID discord.ChannelID, messageID discord.MessageID, content string,
embed *discord.Embed, suppressEmbeds bool) (*discord.Message, error) {
var data = EditMessageData{
Content: option.NewNullableString(content),
Embed: embed,
}
if suppressEmbeds {
data.Flags = &discord.SuppressEmbeds
}
return c.EditMessageComplex(channelID, messageID, data)
}
// EditMessageComplex edits a previously sent message. The fields Content,
// Embed, AllowedMentions and Flags can be edited by the original message
// author. Other users can only edit flags and only if they have the
// MANAGE_MESSAGES permission in the corresponding channel. When specifying
// flags, ensure to include all previously set flags/bits in addition to ones
// that you are modifying. Only flags documented in EditMessageData may be
// modified by users (unsupported flag changes are currently ignored without
// error).
//
// Fires a Message Update Gateway event.
func (c *Client) EditMessage(
channelID, messageID discord.Snowflake, data EditMessageData) (*discord.Message, error) {
func (c *Client) EditMessageComplex(
channelID discord.ChannelID, messageID discord.MessageID, data EditMessageData) (*discord.Message, error) {
if data.AllowedMentions != nil {
if err := data.AllowedMentions.Verify(); err != nil {
return nil, errors.Wrap(err, "allowedMentions error")
}
}
if data.Embed != nil {
if err := data.Embed.Validate(); err != nil {
return nil, errors.Wrap(err, "embed error")
}
}
var msg *discord.Message
return msg, c.RequestJSON(
@ -184,7 +318,7 @@ func (c *Client) EditMessage(
// DeleteMessage delete a message. If operating on a guild channel and trying
// to delete a message that was not sent by the current user, this endpoint
// requires the MANAGE_MESSAGES permission.
func (c *Client) DeleteMessage(channelID, messageID discord.Snowflake) error {
func (c *Client) DeleteMessage(channelID discord.ChannelID, messageID discord.MessageID) error {
return c.FastRequest("DELETE", EndpointChannels+channelID.String()+
"/messages/"+messageID.String())
}
@ -198,9 +332,9 @@ func (c *Client) DeleteMessage(channelID, messageID discord.Snowflake) error {
// provided.
//
// Fires a Message Delete Bulk Gateway event.
func (c *Client) DeleteMessages(channelID discord.Snowflake, messageIDs []discord.Snowflake) error {
func (c *Client) DeleteMessages(channelID discord.ChannelID, messageIDs []discord.MessageID) error {
var param struct {
Messages []discord.Snowflake `json:"messages"`
Messages []discord.MessageID `json:"messages"`
}
param.Messages = messageIDs

View file

@ -7,13 +7,15 @@ import (
"github.com/diamondburned/arikawa/utils/httputil"
)
const maxMessageReactionFetchLimit = 100
// React creates a reaction for the message.
//
// This endpoint requires the READ_MESSAGE_HISTORY permission to be present on
// the current user. Additionally, if nobody else has reacted to the message
// using this emoji, this endpoint requires the 'ADD_REACTIONS' permission to
// be present on the current user.
func (c *Client) React(channelID, messageID discord.Snowflake, emoji Emoji) error {
func (c *Client) React(channelID discord.ChannelID, messageID discord.MessageID, emoji Emoji) error {
var msgURL = EndpointChannels + channelID.String() +
"/messages/" + messageID.String() +
"/reactions/" + url.PathEscape(emoji) + "/@me"
@ -21,67 +23,126 @@ func (c *Client) React(channelID, messageID discord.Snowflake, emoji Emoji) erro
}
// Unreact removes a reaction the current user has made for the message.
func (c *Client) Unreact(chID, msgID discord.Snowflake, emoji Emoji) error {
func (c *Client) Unreact(chID discord.ChannelID, msgID discord.MessageID, emoji Emoji) error {
return c.DeleteUserReaction(chID, msgID, 0, emoji)
}
// Reactions returns reactions up to the specified limit. It will paginate
// automatically.
// Reactions returns a list of users that reacted with the passed Emoji. This
// method automatically paginates until it reaches the passed limit, or, if the
// limit is set to 0, has fetched all users within the passed range.
//
// Max can be 0, in which case the function will try and fetch all reactions.
// As the underlying endpoint has a maximum of 100 users per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
//
// When fetching the users, those with the smallest ID will be fetched first.
func (c *Client) Reactions(
channelID, messageID discord.Snowflake, max uint, emoji Emoji) ([]discord.User, error) {
channelID discord.ChannelID, messageID discord.MessageID, emoji Emoji, limit uint) ([]discord.User, error) {
var users []discord.User
var after discord.Snowflake = 0
return c.ReactionsAfter(channelID, messageID, 0, emoji, limit)
}
const hardLimit int = 100
// ReactionsBefore returns a list of users that reacted with the passed Emoji.
// This method automatically paginates until it reaches the passed limit, or,
// if the limit is set to 0, has fetched all users with an id smaller than
// before.
//
// As the underlying endpoint has a maximum of 100 users per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
func (c *Client) ReactionsBefore(
channelID discord.ChannelID, messageID discord.MessageID, before discord.UserID, emoji Emoji,
limit uint) ([]discord.User, error) {
for fetch := uint(hardLimit); max > 0; fetch = uint(hardLimit) {
if max > 0 {
if fetch > max {
fetch = max
users := make([]discord.User, 0, limit)
fetch := uint(maxMessageReactionFetchLimit)
unlimited := limit == 0
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
}
max -= fetch
limit -= fetch
}
r, err := c.ReactionsRange(channelID, messageID, 0, after, fetch, emoji)
r, err := c.reactionsRange(channelID, messageID, before, 0, emoji, fetch)
if err != nil {
return users, err
}
users = append(users, r...)
users = append(r, users...)
if len(r) < hardLimit {
if len(r) < maxMessageReactionFetchLimit {
break
}
after = r[hardLimit-1].ID
before = r[0].ID
}
if len(users) == 0 {
return nil, nil
}
return users, nil
}
// ReactionsBefore gets all reactions before the passed user ID.
func (c *Client) ReactionsBefore(
channelID, messageID, before discord.Snowflake,
limit uint, emoji Emoji) ([]discord.User, error) {
return c.ReactionsRange(channelID, messageID, before, 0, limit, emoji)
}
// Refer to ReactionsRange.
// ReactionsAfter returns a list of users that reacted with the passed Emoji.
// This method automatically paginates until it reaches the passed limit, or,
// if the limit is set to 0, has fetched all users with an id higher than
// after.
//
// As the underlying endpoint has a maximum of 100 users per request, at
// maximum a total of limit/100 rounded up requests will be made, although they
// may be less, if no more guilds are available.
func (c *Client) ReactionsAfter(
channelID, messageID, after discord.Snowflake,
limit uint, emoji Emoji) ([]discord.User, error) {
channelID discord.ChannelID, messageID discord.MessageID, after discord.UserID, emoji Emoji,
limit uint) ([]discord.User, error) {
return c.ReactionsRange(channelID, messageID, 0, after, limit, emoji)
users := make([]discord.User, 0, limit)
fetch := uint(maxMessageReactionFetchLimit)
unlimited := limit == 0
for limit > 0 || unlimited {
// Only fetch as much as we need. Since limit gradually decreases,
// we only need to fetch min(fetch, limit).
if limit > 0 {
if fetch > limit {
fetch = limit
}
limit -= fetch
}
r, err := c.reactionsRange(channelID, messageID, 0, after, emoji, fetch)
if err != nil {
return users, err
}
users = append(users, r...)
if len(r) < maxMessageReactionFetchLimit {
break
}
after = r[len(r)-1].ID
}
if len(users) == 0 {
return nil, nil
}
return users, nil
}
// ReactionsRange get users before and after IDs. Before, after, and limit are
// reactionsRange get users before and after IDs. Before, after, and limit are
// optional. A maximum limit of only 100 reactions could be returned.
func (c *Client) ReactionsRange(
channelID, messageID, before, after discord.Snowflake,
limit uint, emoji Emoji) ([]discord.User, error) {
func (c *Client) reactionsRange(
channelID discord.ChannelID, messageID discord.MessageID, before, after discord.UserID, emoji Emoji,
limit uint) ([]discord.User, error) {
switch {
case limit == 0:
@ -91,8 +152,8 @@ func (c *Client) ReactionsRange(
}
var param struct {
Before discord.Snowflake `schema:"before,omitempty"`
After discord.Snowflake `schema:"after,omitempty"`
Before discord.UserID `schema:"before,omitempty"`
After discord.UserID `schema:"after,omitempty"`
Limit uint `schema:"limit"`
}
@ -115,7 +176,7 @@ func (c *Client) ReactionsRange(
// This endpoint requires the MANAGE_MESSAGES permission to be present on the
// current user.
func (c *Client) DeleteUserReaction(
channelID, messageID, userID discord.Snowflake, emoji Emoji) error {
channelID discord.ChannelID, messageID discord.MessageID, userID discord.UserID, emoji Emoji) error {
var user = "@me"
if userID > 0 {
@ -135,11 +196,11 @@ func (c *Client) DeleteUserReaction(
// current user.
// Fires a Message Reaction Remove Emoji Gateway event.
func (c *Client) DeleteReactions(
channelId, messageID discord.Snowflake, emoji Emoji) error {
channelID discord.ChannelID, messageID discord.MessageID, emoji Emoji) error {
return c.FastRequest(
"DELETE",
EndpointChannels+channelId.String()+"/messages/"+messageID.String()+
EndpointChannels+channelID.String()+"/messages/"+messageID.String()+
"/reactions/"+url.PathEscape(emoji),
)
}
@ -149,7 +210,7 @@ func (c *Client) DeleteReactions(
// This endpoint requires the MANAGE_MESSAGES permission to be present on the
// current user.
// Fires a Message Reaction Remove All Gateway event.
func (c *Client) DeleteAllReactions(channelID, messageID discord.Snowflake) error {
func (c *Client) DeleteAllReactions(channelID discord.ChannelID, messageID discord.MessageID) error {
return c.FastRequest(
"DELETE",
EndpointChannels+channelID.String()+"/messages/"+messageID.String()+"/reactions/",

View file

@ -9,8 +9,9 @@ import (
"sync/atomic"
"time"
"github.com/diamondburned/arikawa/utils/moreatomic"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/internal/moreatomic"
)
// ExtraDelay because Discord is trash. I've seen this in both litcord and
@ -27,9 +28,8 @@ type Limiter struct {
Prefix string
global *int64 // atomic guarded, unixnano
buckets sync.Map
globalRate time.Duration
global *int64 // atomic guarded, unixnano
buckets sync.Map
}
type CustomRateLimit struct {
@ -42,7 +42,6 @@ type bucket struct {
custom *CustomRateLimit
remaining uint64
limit uint
reset time.Time
lastReset time.Time // only for custom
@ -101,7 +100,7 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
if b.remaining == 0 && b.reset.After(time.Now()) {
// out of turns, gotta wait
sleep = b.reset.Sub(time.Now())
sleep = time.Until(b.reset)
} else {
// maybe global rate limit has it
now := time.Now()
@ -129,14 +128,15 @@ func (l *Limiter) Acquire(ctx context.Context, path string) error {
}
// Release releases the URL from the locks. This doesn't need a context for
// timing out, it doesn't block that much.
// timing out, since it doesn't block that much.
func (l *Limiter) Release(path string, headers http.Header) error {
b := l.getBucket(path, false)
if b == nil {
return nil
}
defer b.lock.Unlock()
// TryUnlock because Release may be called when Acquire has not been.
defer b.lock.TryUnlock()
// Check custom limiter
if b.custom != nil {
@ -169,7 +169,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
case retryAfter != "":
i, err := strconv.Atoi(retryAfter)
if err != nil {
return errors.Wrap(err, "Invalid retryAfter "+retryAfter)
return errors.Wrap(err, "invalid retryAfter "+retryAfter)
}
at := time.Now().Add(time.Duration(i) * time.Millisecond)
@ -183,7 +183,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
case reset != "":
unix, err := strconv.ParseFloat(reset, 64)
if err != nil {
return errors.Wrap(err, "Invalid reset "+reset)
return errors.Wrap(err, "invalid reset "+reset)
}
b.reset = time.Unix(0, int64(unix*float64(time.Second))).
@ -193,7 +193,7 @@ func (l *Limiter) Release(path string, headers http.Header) error {
if remaining != "" {
u, err := strconv.ParseUint(remaining, 10, 64)
if err != nil {
return errors.Wrap(err, "Invalid remaining "+remaining)
return errors.Wrap(err, "invalid remaining "+remaining)
}
b.remaining = u

View file

@ -46,7 +46,7 @@ func TestRatelimitReset(t *testing.T) {
if time.Since(sent) >= time.Second && time.Since(sent) < time.Second*4 {
t.Log("OK", time.Since(sent))
} else {
t.Error("Did not ratelimit correctly, got:", time.Since(sent))
t.Error("did not ratelimit correctly, got:", time.Since(sent))
}
}
@ -71,6 +71,6 @@ func TestRatelimitGlobal(t *testing.T) {
if time.Since(sent) >= time.Second && time.Since(sent) < time.Second*2 {
t.Log("OK", time.Since(sent))
} else {
t.Error("Did not ratelimit correctly, got:", time.Since(sent))
t.Error("did not ratelimit correctly, got:", time.Since(sent))
}
}

View file

@ -9,7 +9,7 @@ import (
// Adds a role to a guild member.
//
// Requires the MANAGE_ROLES permission.
func (c *Client) AddRole(guildID, userID, roleID discord.Snowflake) error {
func (c *Client) AddRole(guildID discord.GuildID, userID discord.UserID, roleID discord.RoleID) error {
return c.FastRequest(
"PUT",
EndpointGuilds+guildID.String()+"/members/"+userID.String()+"/roles/"+roleID.String(),
@ -20,7 +20,7 @@ func (c *Client) AddRole(guildID, userID, roleID discord.Snowflake) error {
//
// Requires the MANAGE_ROLES permission.
// Fires a Guild Member Update Gateway event.
func (c *Client) RemoveRole(guildID, userID, roleID discord.Snowflake) error {
func (c *Client) RemoveRole(guildID discord.GuildID, userID discord.UserID, roleID discord.RoleID) error {
return c.FastRequest(
"DELETE",
EndpointGuilds+guildID.String()+"/members/"+userID.String()+"/roles/"+roleID.String(),
@ -28,7 +28,7 @@ func (c *Client) RemoveRole(guildID, userID, roleID discord.Snowflake) error {
}
// Roles returns a list of role objects for the guild.
func (c *Client) Roles(guildID discord.Snowflake) ([]discord.Role, error) {
func (c *Client) Roles(guildID discord.GuildID) ([]discord.Role, error) {
var roles []discord.Role
return roles, c.RequestJSON(&roles, "GET", EndpointGuilds+guildID.String()+"/roles")
}
@ -42,7 +42,7 @@ type CreateRoleData struct {
// Permissions is the bitwise value of the enabled/disabled permissions.
//
// Default: @everyone permissions in guild
Permissions discord.Permissions `json:"permissions,omitempty"`
Permissions discord.Permissions `json:"permissions,omitempty,string"`
// Color is the RGB color value of the role.
//
// Default: 0
@ -62,8 +62,7 @@ type CreateRoleData struct {
//
// Requires the MANAGE_ROLES permission.
// Fires a Guild Role Create Gateway event.
func (c *Client) CreateRole(
guildID discord.Snowflake, data CreateRoleData) (*discord.Role, error) {
func (c *Client) CreateRole(guildID discord.GuildID, data CreateRoleData) (*discord.Role, error) {
var role *discord.Role
return role, c.RequestJSON(
@ -76,7 +75,7 @@ func (c *Client) CreateRole(
// https://discord.com/developers/docs/resources/guild#modify-guild-role-positions-json-params
type MoveRoleData struct {
// ID is the id of the role.
ID discord.Snowflake `json:"id"`
ID discord.RoleID `json:"id"`
// Position is the sorting position of the role.
Position option.NullableInt `json:"position,omitempty"`
}
@ -85,7 +84,7 @@ type MoveRoleData struct {
//
// Requires the MANAGE_ROLES permission.
// Fires multiple Guild Role Update Gateway events.
func (c *Client) MoveRole(guildID discord.Snowflake, data []MoveRoleData) ([]discord.Role, error) {
func (c *Client) MoveRole(guildID discord.GuildID, data []MoveRoleData) ([]discord.Role, error) {
var roles []discord.Role
return roles, c.RequestJSON(
&roles, "PATCH",
@ -99,7 +98,7 @@ type ModifyRoleData struct {
// Name is the name of the role.
Name option.NullableString `json:"name,omitempty"`
// Permissions is the bitwise value of the enabled/disabled permissions.
Permissions *discord.Permissions `json:"permissions,omitempty"`
Permissions *discord.Permissions `json:"permissions,omitempty,string"`
// Permissions is the bitwise value of the enabled/disabled permissions.
Color option.NullableColor `json:"color,omitempty"`
// Hoist specifies whether the role should be displayed separately in the
@ -113,7 +112,7 @@ type ModifyRoleData struct {
//
// Requires the MANAGE_ROLES permission.
func (c *Client) ModifyRole(
guildID, roleID discord.Snowflake,
guildID discord.GuildID, roleID discord.RoleID,
data ModifyRoleData) (*discord.Role, error) {
var role *discord.Role
@ -127,7 +126,7 @@ func (c *Client) ModifyRole(
// DeleteRole deletes a guild role.
//
// Requires the MANAGE_ROLES permission.
func (c *Client) DeleteRole(guildID, roleID discord.Snowflake) error {
func (c *Client) DeleteRole(guildID discord.GuildID, roleID discord.RoleID) error {
return c.FastRequest(
"DELETE",
EndpointGuilds+guildID.String()+"/roles/"+roleID.String(),

View file

@ -3,20 +3,17 @@ package api
import (
"io"
"mime/multipart"
"net/url"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/diamondburned/arikawa/utils/json"
"github.com/pkg/errors"
)
const AttachmentSpoilerPrefix = "SPOILER_"
var quoteEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`)
// AllowedMentions is a whitelist of mentions for a message.
// https://discordapp.com/developers/docs/resources/channel#allowed-mentions-object
//
@ -41,9 +38,9 @@ type AllowedMentions struct {
// Parse is an array of allowed mention types to parse from the content.
Parse []AllowedMentionType `json:"parse"`
// Roles is an array of role_ids to mention (Max size of 100).
Roles []discord.Snowflake `json:"roles,omitempty"`
Roles []discord.RoleID `json:"roles,omitempty"`
// Users is an array of user_ids to mention (Max size of 100).
Users []discord.Snowflake `json:"users,omitempty"`
Users []discord.UserID `json:"users,omitempty"`
}
// AllowedMentionType is a constant that tells Discord what is allowed to parse
@ -64,21 +61,21 @@ const (
// AllowedMentions' documentation. This will be called on SendMessageComplex.
func (am AllowedMentions) Verify() error {
if len(am.Roles) > 100 {
return errors.Errorf("Roles slice length %d is over 100", len(am.Roles))
return errors.Errorf("roles slice length %d is over 100", len(am.Roles))
}
if len(am.Users) > 100 {
return errors.Errorf("Users slice length %d is over 100", len(am.Users))
return errors.Errorf("users slice length %d is over 100", len(am.Users))
}
for _, allowed := range am.Parse {
switch allowed {
case AllowRoleMention:
if len(am.Roles) > 0 {
return errors.New(`Parse has AllowRoleMention and Roles slice is not empty`)
return errors.New(`parse has AllowRoleMention and Roles slice is not empty`)
}
case AllowUserMention:
if len(am.Users) > 0 {
return errors.New(`Parse has AllowUserMention and Users slice is not empty`)
return errors.New(`parse has AllowUserMention and Users slice is not empty`)
}
}
}
@ -88,7 +85,7 @@ func (am AllowedMentions) Verify() error {
// ErrEmptyMessage is returned if either a SendMessageData or an
// ExecuteWebhookData has both an empty Content and no Embed(s).
var ErrEmptyMessage = errors.New("Message is empty")
var ErrEmptyMessage = errors.New("message is empty")
// SendMessageFile represents a file to be uploaded to Discord.
type SendMessageFile struct {
@ -138,7 +135,7 @@ func (data *SendMessageData) WriteMultipart(body *multipart.Writer) error {
// least one of content, embed or file. For a file attachment, the
// Content-Disposition subpart header MUST contain a filename parameter.
func (c *Client) SendMessageComplex(
channelID discord.Snowflake, data SendMessageData) (*discord.Message, error) {
channelID discord.ChannelID, data SendMessageData) (*discord.Message, error) {
if data.Content == "" && data.Embed == nil && len(data.Files) == 0 {
return nil, ErrEmptyMessage
@ -146,13 +143,13 @@ func (c *Client) SendMessageComplex(
if data.AllowedMentions != nil {
if err := data.AllowedMentions.Verify(); err != nil {
return nil, errors.Wrap(err, "AllowedMentions error")
return nil, errors.Wrap(err, "allowedMentions error")
}
}
if data.Embed != nil {
if err := data.Embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error")
return nil, errors.Wrap(err, "embed error")
}
}
@ -207,76 +204,17 @@ func (data *ExecuteWebhookData) WriteMultipart(body *multipart.Writer) error {
return writeMultipart(body, data, data.Files)
}
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
// wait for the message to be delivered and will return the message body. This
// also means the returned message will only be there if wait is true.
func (c *Client) ExecuteWebhook(
webhookID discord.Snowflake,
token string,
wait bool, // if false, then nil returned for *Message.
data ExecuteWebhookData) (*discord.Message, error) {
if data.Content == "" && len(data.Embeds) == 0 && len(data.Files) == 0 {
return nil, ErrEmptyMessage
}
if data.AllowedMentions != nil {
if err := data.AllowedMentions.Verify(); err != nil {
return nil, errors.Wrap(err, "AllowedMentions error")
}
}
for i, embed := range data.Embeds {
if err := embed.Validate(); err != nil {
return nil, errors.Wrap(err, "Embed error at "+strconv.Itoa(i))
}
}
var param = url.Values{}
if wait {
param.Set("wait", "true")
}
var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode()
var msg *discord.Message
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
if err != nil {
return nil, err
}
var body = resp.GetBody()
defer body.Close()
if !wait {
// Since we didn't tell Discord to wait, we have nothing to parse.
return nil, nil
}
return msg, json.DecodeStream(body, &msg)
}
func writeMultipart(body *multipart.Writer, item interface{}, files []SendMessageFile) error {
defer body.Close()
// Encode the JSON body first
w, err := body.CreateFormField("payload_json")
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for JSON")
return errors.Wrap(err, "failed to create bodypart for JSON")
}
if err := json.EncodeStream(w, item); err != nil {
return errors.Wrap(err, "Failed to encode JSON")
return errors.Wrap(err, "failed to encode JSON")
}
for i, file := range files {
@ -284,11 +222,11 @@ func writeMultipart(body *multipart.Writer, item interface{}, files []SendMessag
w, err := body.CreateFormFile("file"+num, file.Name)
if err != nil {
return errors.Wrap(err, "Failed to create bodypart for "+num)
return errors.Wrap(err, "failed to create bodypart for "+num)
}
if _, err := io.Copy(w, file.Reader); err != nil {
return errors.Wrap(err, "Failed to write for file "+num)
return errors.Wrap(err, "failed to write for file "+num)
}
}

View file

@ -34,7 +34,7 @@ func TestMarshalAllowedMentions(t *testing.T) {
t.Run("allow certain user IDs", func(t *testing.T) {
var data = SendMessageData{
AllowedMentions: &AllowedMentions{
Users: []discord.Snowflake{1, 2},
Users: []discord.UserID{1, 2},
},
}
@ -48,7 +48,7 @@ func TestVerifyAllowedMentions(t *testing.T) {
t.Run("invalid", func(t *testing.T) {
var am = AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Users: []discord.Snowflake{69, 420},
Users: []discord.UserID{69, 420},
}
err := am.Verify()
@ -57,27 +57,27 @@ func TestVerifyAllowedMentions(t *testing.T) {
t.Run("users too long", func(t *testing.T) {
var am = AllowedMentions{
Users: make([]discord.Snowflake, 101),
Users: make([]discord.UserID, 101),
}
err := am.Verify()
errMustContain(t, err, "Users slice length 101 is over 100")
errMustContain(t, err, "users slice length 101 is over 100")
})
t.Run("roles too long", func(t *testing.T) {
var am = AllowedMentions{
Roles: make([]discord.Snowflake, 101),
Roles: make([]discord.RoleID, 101),
}
err := am.Verify()
errMustContain(t, err, "Roles slice length 101 is over 100")
errMustContain(t, err, "roles slice length 101 is over 100")
})
t.Run("valid", func(t *testing.T) {
var am = AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Roles: []discord.Snowflake{1337},
Users: []discord.Snowflake{},
Roles: []discord.RoleID{1337},
Users: []discord.UserID{},
}
if err := am.Verify(); err != nil {
@ -125,12 +125,12 @@ func TestSendMessage(t *testing.T) {
Content: "hime arikawa",
AllowedMentions: &AllowedMentions{
Parse: []AllowedMentionType{AllowEveryoneMention, AllowUserMention},
Users: []discord.Snowflake{69, 420},
Users: []discord.UserID{69, 420},
},
}
err := send(data)
errMustContain(t, err, "AllowedMentions error")
errMustContain(t, err, "allowedMentions error")
})
t.Run("invalid embed", func(t *testing.T) {
@ -142,7 +142,7 @@ func TestSendMessage(t *testing.T) {
}
err := send(data)
errMustContain(t, err, "Embed error")
errMustContain(t, err, "embed error")
})
}

View file

@ -12,7 +12,7 @@ var (
)
// User returns a user object for a given user ID.
func (c *Client) User(userID discord.Snowflake) (*discord.User, error) {
func (c *Client) User(userID discord.UserID) (*discord.User, error) {
var u *discord.User
return u, c.RequestJSON(&u, "GET", EndpointUsers+userID.String())
}
@ -40,30 +40,11 @@ func (c *Client) ModifyMe(data ModifySelfData) (*discord.User, error) {
return u, c.RequestJSON(&u, "PATCH", EndpointMe, httputil.WithJSONBody(data))
}
// PrivateChannels returns a list of DM channel objects. For bots, this is no
// longer a supported method of getting recent DMs, and will return an empty
// array.
func (c *Client) PrivateChannels() ([]discord.Channel, error) {
var dms []discord.Channel
return dms, c.RequestJSON(&dms, "GET", EndpointMe+"/channels")
}
// CreatePrivateChannel creates a new DM channel with a user.
func (c *Client) CreatePrivateChannel(recipientID discord.Snowflake) (*discord.Channel, error) {
var param struct {
RecipientID discord.Snowflake `json:"recipient_id"`
}
param.RecipientID = recipientID
var dm *discord.Channel
return dm, c.RequestJSON(&dm, "POST", EndpointMe+"/channels", httputil.WithJSONBody(param))
}
// ChangeOwnNickname only replies with the nickname back, so we're not even
// going to bother.
// ChangeOwnNickname modifies the nickname of the current user in a guild.
//
// Fires a Guild Member Update Gateway event.
func (c *Client) ChangeOwnNickname(
guildID discord.Snowflake, nick string) error {
guildID discord.GuildID, nick string) error {
var param struct {
Nick string `json:"nick"`
@ -78,7 +59,65 @@ func (c *Client) ChangeOwnNickname(
)
}
// shitty SDK, don't care, PR welcomed
// func (c *Client) CreateGroup(tokens []string, nicks map[])
// PrivateChannels returns a list of DM channel objects. For bots, this is no
// longer a supported method of getting recent DMs, and will return an empty
// array.
func (c *Client) PrivateChannels() ([]discord.Channel, error) {
var dms []discord.Channel
return dms, c.RequestJSON(&dms, "GET", EndpointMe+"/channels")
}
// func (c *Client) UserConnections() ([]discord.Connection, error) {}
// CreatePrivateChannel creates a new DM channel with a user.
func (c *Client) CreatePrivateChannel(recipientID discord.UserID) (*discord.Channel, error) {
var param struct {
RecipientID discord.UserID `json:"recipient_id"`
}
param.RecipientID = recipientID
var dm *discord.Channel
return dm, c.RequestJSON(&dm, "POST", EndpointMe+"/channels", httputil.WithJSONBody(param))
}
// UserConnections returns a list of connection objects. Requires the
// connections OAuth2 scope.
func (c *Client) UserConnections() ([]discord.Connection, error) {
var conn []discord.Connection
return conn, c.RequestJSON(&conn, "GET", EndpointMe+"/connections")
}
// SetNote sets a note for the user. This endpoint is undocumented and might
// only work for user accounts.
func (c *Client) SetNote(userID discord.UserID, note string) error {
var body = struct {
Note string `json:"note"`
}{
Note: note,
}
return c.FastRequest(
"PUT", EndpointMe+"/notes/"+userID.String(),
httputil.WithJSONBody(body),
)
}
// SetRelationship sets the relationship type between the current user and the
// given user.
func (c *Client) SetRelationship(userID discord.UserID, t discord.RelationshipType) error {
var body = struct {
Type discord.RelationshipType `json:"type"`
}{
Type: t,
}
return c.FastRequest(
"PUT", EndpointMe+"/relationships/"+userID.String(),
httputil.WithJSONBody(body),
)
}
// DeleteRelationship deletes the relationship between the current user and the
// given user.
func (c *Client) DeleteRelationship(userID discord.UserID) error {
return c.FastRequest("DELETE", EndpointMe+"/relationships/"+userID.String())
}

View file

@ -22,29 +22,20 @@ type CreateWebhookData struct {
//
// Requires the MANAGE_WEBHOOKS permission.
func (c *Client) CreateWebhook(
channelID discord.Snowflake,
name string, avatar discord.Hash) (*discord.Webhook, error) {
var param struct {
Name string `json:"name"`
Avatar discord.Hash `json:"avatar"`
}
param.Name = name
param.Avatar = avatar
channelID discord.ChannelID, data CreateWebhookData) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(
&w, "POST",
EndpointChannels+channelID.String()+"/webhooks",
httputil.WithJSONBody(param),
httputil.WithJSONBody(data),
)
}
// ChannelWebhooks returns the webhooks of the channel with the given ID.
//
// Requires the MANAGE_WEBHOOKS permission.
func (c *Client) ChannelWebhooks(channelID discord.Snowflake) ([]discord.Webhook, error) {
func (c *Client) ChannelWebhooks(channelID discord.ChannelID) ([]discord.Webhook, error) {
var ws []discord.Webhook
return ws, c.RequestJSON(&ws, "GET", EndpointChannels+channelID.String()+"/webhooks")
}
@ -52,26 +43,17 @@ func (c *Client) ChannelWebhooks(channelID discord.Snowflake) ([]discord.Webhook
// GuildWebhooks returns the webhooks of the guild with the given ID.
//
// Requires the MANAGE_WEBHOOKS permission.
func (c *Client) GuildWebhooks(guildID discord.Snowflake) ([]discord.Webhook, error) {
func (c *Client) GuildWebhooks(guildID discord.GuildID) ([]discord.Webhook, error) {
var ws []discord.Webhook
return ws, c.RequestJSON(&ws, "GET", EndpointGuilds+guildID.String()+"/webhooks")
}
// Webhook returns the webhook with the given id.
func (c *Client) Webhook(webhookID discord.Snowflake) (*discord.Webhook, error) {
func (c *Client) Webhook(webhookID discord.WebhookID) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String())
}
// WebhookWithToken is the same as above, except this call does not require
// authentication and returns no user in the webhook object.
func (c *Client) WebhookWithToken(
webhookID discord.Snowflake, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String()+"/"+token)
}
// https://discord.com/developers/docs/resources/webhook#modify-webhook-json-params
type ModifyWebhookData struct {
// Name is the default name of the webhook.
@ -79,14 +61,14 @@ type ModifyWebhookData struct {
// Avatar is the image for the default webhook avatar.
Avatar *Image `json:"avatar,omitempty"`
// ChannelID is the new channel id this webhook should be moved to.
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
ChannelID discord.ChannelID `json:"channel_id,omitempty"`
}
// ModifyWebhook modifies a webhook.
//
// Requires the MANAGE_WEBHOOKS permission.
func (c *Client) ModifyWebhook(
webhookID discord.Snowflake, data ModifyWebhookData) (*discord.Webhook, error) {
webhookID discord.WebhookID, data ModifyWebhookData) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(
@ -96,29 +78,9 @@ func (c *Client) ModifyWebhook(
)
}
// ModifyWebhookWithToken is the same as above, except this call does not
// require authentication, does not accept a channel_id parameter in the body,
// and does not return a user in the webhook object.
func (c *Client) ModifyWebhookWithToken(
webhookID discord.Snowflake, data ModifyWebhookData, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(
&w, "PATCH",
EndpointWebhooks+webhookID.String()+"/"+token,
httputil.WithJSONBody(data),
)
}
// DeleteWebhook deletes a webhook permanently.
//
// Requires the MANAGE_WEBHOOKS permission.
func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error {
func (c *Client) DeleteWebhook(webhookID discord.WebhookID) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String())
}
// DeleteWebhookWithToken is the same as above, except this call does not
// require authentication.
func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token)
}

View file

@ -2,6 +2,8 @@
Not a lot for a Discord bot:
# THIS IS OUTDATED. TODO: UPDATE.
```
# Cold functions, or functions that are called once in runtime:
BenchmarkConstructor-8 150537 7617 ns/op

View file

@ -30,52 +30,50 @@ type ManualParser interface {
ParseContent([]string) error
}
// ArgumentParts implements ManualParseable, in case you want to parse arguments
// ArgumentParts implements ManualParser, in case you want to parse arguments
// manually. It borrows the library's argument parser.
type ArgumentParts struct {
Command string
Arguments []string
}
type ArgumentParts []string
var _ ManualParser = (*ArgumentParts)(nil)
// ParseContent implements ManualParser.
func (r *ArgumentParts) ParseContent(args []string) error {
r.Command = args[0]
if len(args) > 1 {
r.Arguments = args[1:]
}
*r = args
return nil
}
func (r ArgumentParts) Arg(n int) string {
if n < 0 || n >= len(r.Arguments) {
if n < 0 || n >= len(r) {
return ""
}
return r.Arguments[n]
return r[n]
}
func (r ArgumentParts) After(n int) string {
if n < 0 || n >= len(r.Arguments) {
if n < 0 || n > len(r) {
return ""
}
return strings.Join(r.Arguments[n:], " ")
return strings.Join(r[n:], " ")
}
func (r ArgumentParts) String() string {
return r.Command + " " + strings.Join(r.Arguments, " ")
return strings.Join(r, " ")
}
func (r ArgumentParts) Length() int {
return len(r.Arguments)
return len(r)
}
// Usage implements Usager.
func (r ArgumentParts) Usage() string {
return "strings"
}
// CustomParser has a CustomParse method, which would be passed in the full
// message content with the prefix and command trimmed. This is used
// for commands that require more advanced parsing than the default parser.
//
// Keep in mind that this does not trim arguments before it.
type CustomParser interface {
CustomParse(arguments string) error
}
@ -119,7 +117,7 @@ var ShellwordsEscaper = strings.NewReplacer(
var nilV = reflect.Value{}
func newArgument(t reflect.Type, variadic bool) (*Argument, error) {
// Allow array types if varidic is true.
// Allow array types if variadic is true.
if variadic && t.Kind() == reflect.Slice {
t = t.Elem()
}
@ -132,7 +130,7 @@ func newArgument(t reflect.Type, variadic bool) (*Argument, error) {
ptr = true
}
// This shouldn't be varidic.
// This shouldn't be variadic.
if !variadic && typeI.Implements(typeICusP) {
mt, _ := typeI.MethodByName("CustomParse")
@ -142,7 +140,7 @@ func newArgument(t reflect.Type, variadic bool) (*Argument, error) {
}
return &Argument{
String: t.String(),
String: fromUsager(t),
rtype: t,
pointer: ptr,
custom: &mt,
@ -158,7 +156,7 @@ func newArgument(t reflect.Type, variadic bool) (*Argument, error) {
}
return &Argument{
String: t.String(),
String: fromUsager(t),
rtype: t,
pointer: ptr,
manual: &mt,
@ -242,7 +240,7 @@ func newArgument(t reflect.Type, variadic bool) (*Argument, error) {
}
return &Argument{
String: t.String(),
String: fromUsager(t),
rtype: t,
fn: fn,
}, nil
@ -264,12 +262,9 @@ func quickRet(v interface{}, err error, t reflect.Type) (reflect.Value, error) {
func fromUsager(typeI reflect.Type) string {
if typeI.Implements(typeIUsager) {
mt, ok := typeI.MethodByName("Usage")
if !ok {
panic("BUG: type IUsager does not implement Usage")
}
mt, _ := typeI.MethodByName("Usage")
vs := mt.Func.Call([]reflect.Value{reflect.New(typeI.Elem())})
vs := mt.Func.Call([]reflect.Value{reflect.New(typeI).Elem()})
return vs[0].String()
}

View file

@ -51,15 +51,6 @@ func testArgs(t *testing.T, expect interface{}, input string) {
// used for ctx_test.go
type customManualParsed struct {
args []string
}
func (c *customManualParsed) ParseContent(args []string) error {
c.args = args
return nil
}
type customParsed struct {
parsed bool
}

243
bot/command.go Normal file
View file

@ -0,0 +1,243 @@
package bot
import (
"reflect"
)
type command struct {
value reflect.Value // Func
event reflect.Type
isInterface bool
}
func newCommand(value reflect.Value, event reflect.Type) command {
return command{
value: value,
event: event,
isInterface: event.Kind() == reflect.Interface,
}
}
func (c *command) isEvent(t reflect.Type) bool {
return (!c.isInterface && c.event == t) || (c.isInterface && t.Implements(c.event))
}
func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
return callWith(c.value, arg0, argv...)
}
func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(argv))
if v, ok := arg0.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(arg0))
}
callargs = append(callargs, argv...)
return errorReturns(caller.Call(callargs))
}
type caller interface {
call(arg0 interface{}, argv ...reflect.Value) (interface{}, error)
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil
}
// Treat the last return as an error.
return nil, v.(error)
}
// MethodContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type MethodContext struct {
command
method reflect.Method // extend
middlewares []*MiddlewareContext
Description string
// MethodName is the name of the method. This field should NOT be changed.
MethodName string
// Command is the Discord command used to call the method.
Command string // plumb if empty
// Aliases is alternative way to call command in Discord.
Aliases []string
// Hidden if true will not be shown by (*Subcommand).HelpGenerate().
Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
Arguments []Argument
}
func parseMethod(value reflect.Value, method reflect.Method) *MethodContext {
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
return nil
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
return nil
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
return nil
}
}
var command = MethodContext{
command: newCommand(value, methodT.In(0)),
method: method,
MethodName: method.Name,
Variadic: methodT.IsVariadic(),
}
// Only set the command name if it's a MessageCreate handler.
if command.event == typeMessageCreate {
command.Command = lowerFirstLetter(command.method.Name)
}
if numArgs > 1 {
// Event handlers that aren't MessageCreate should not have arguments.
if command.event != typeMessageCreate {
return nil
}
// If the event type is messageCreate:
command.Arguments = make([]Argument, 0, numArgs-1)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
panic("error parsing argument " + t.String() + ": " + err.Error())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
}
return &command
}
func (cctx *MethodContext) addMiddleware(mw *MiddlewareContext) {
// Skip if mismatch type:
if !mw.command.isEvent(cctx.command.event) {
return
}
cctx.middlewares = append(cctx.middlewares, mw)
}
func (cctx *MethodContext) walkMiddlewares(ev reflect.Value) error {
for _, mw := range cctx.middlewares {
_, err := mw.call(ev)
if err != nil {
return err
}
}
return nil
}
func (cctx *MethodContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
}
// SetName sets the command name.
func (cctx *MethodContext) SetName(name string) {
cctx.Command = name
}
type MiddlewareContext struct {
command
}
// ParseMiddleware parses a middleware function. This function panics.
func ParseMiddleware(mw interface{}) *MiddlewareContext {
value := reflect.ValueOf(mw)
methodT := value.Type()
numArgs := methodT.NumIn()
if numArgs != 1 {
panic("Invalid argument signature for " + methodT.String())
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
if numOut > 1 {
panic("Invalid return signature for " + methodT.String())
}
// Check the last return's type if the method returns anything.
if numOut == 1 {
if i := methodT.Out(0); i == nil || !i.Implements(typeIError) {
panic("unexpected return type (not error) for " + methodT.String())
}
}
var middleware = MiddlewareContext{
command: newCommand(value, methodT.In(0)),
}
return &middleware
}

View file

@ -7,11 +7,12 @@ import (
"strings"
"sync"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/bot/extras/shellwords"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
"github.com/pkg/errors"
)
// Prefixer checks a message if it starts with the desired prefix. By default,
@ -94,6 +95,18 @@ type Context struct {
// This is false by default and only applies to MessageCreate.
AllowBot bool
// QuietUnknownCommand, if true, will not make the bot reply with an unknown
// command error into the chat. This will apply to all other subcommands.
// SilentUnknown controls whether or not an ErrUnknownCommand should be
// returned (instead of a silent error).
SilentUnknown struct {
// Command when true will silent only unknown commands. Known
// subcommands with unknown commands will still error out.
Command bool
// Subcommand when true will suppress unknown subcommands.
Subcommand bool
}
// FormatError formats any errors returned by anything, including the method
// commands or the reflect functions. This also includes invalid usage
// errors or unknown command errors. Returning an empty string means
@ -112,6 +125,11 @@ type Context struct {
// MessageCreate events.
ReplyError bool
// EditableCommands when true will also listen for MessageUpdateEvent and
// treat them as newly created messages. This is convenient if you want
// to quickly edit a message and re-execute the command.
EditableCommands bool
// Subcommands contains all the registered subcommands. This is not
// exported, as it shouldn't be used directly.
subcommands []*Subcommand
@ -123,17 +141,18 @@ type Context struct {
// Start quickly starts a bot with the given command. It will prepend "Bot"
// into the token automatically. Refer to example/ for usage.
func Start(token string, cmd interface{},
func Start(
token string, cmd interface{},
opts func(*Context) error) (wait func() error, err error) {
s, err := state.New("Bot " + token)
if err != nil {
return nil, errors.Wrap(err, "Failed to create a dgo session")
return nil, errors.Wrap(err, "failed to create a dgo session")
}
c, err := New(s, cmd)
if err != nil {
return nil, errors.Wrap(err, "Failed to create rfrouter")
return nil, errors.Wrap(err, "failed to create rfrouter")
}
s.Gateway.ErrorLog = func(err error) {
@ -149,7 +168,7 @@ func Start(token string, cmd interface{},
cancel := c.Start()
if err := s.Open(); err != nil {
return nil, errors.Wrap(err, "Failed to connect to Discord")
return nil, errors.Wrap(err, "failed to connect to Discord")
}
return func() error {
@ -163,7 +182,7 @@ func Start(token string, cmd interface{},
// Wait blocks until SIGINT.
func Wait() {
sigs := make(chan os.Signal)
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, os.Interrupt)
<-sigs
}
@ -176,7 +195,7 @@ func Wait() {
// }
//
// cmds := &Commands{}
// c, err := rfrouter.New(session, cmds)
// c, err := bot.New(session, cmds)
//
// The default prefix is "~", which means commands must start with "~" followed
// by the command name in the first argument, else it will be ignored.
@ -204,12 +223,18 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
}
if err := ctx.InitCommands(ctx); err != nil {
return nil, errors.Wrap(err, "Failed to initialize with given cmds")
return nil, errors.Wrap(err, "failed to initialize with given cmds")
}
return ctx, nil
}
// AddIntent adds the given Gateway Intent into the Gateway. This is a
// convenient function that calls Gateway's AddIntent.
func (ctx *Context) AddIntent(i gateway.Intents) {
ctx.Gateway.AddIntent(i)
}
// Subcommands returns the slice of subcommands. To add subcommands, use
// RegisterSubcommand().
func (ctx *Context) Subcommands() []*Subcommand {
@ -217,39 +242,23 @@ func (ctx *Context) Subcommands() []*Subcommand {
return ctx.subcommands
}
// FindCommand finds a command based on the struct and method name. The queried
// FindMethod finds a method based on the struct and method name. The queried
// names will have their flags stripped.
//
// Example
//
// // Find a command from the main context:
// cmd := ctx.FindCommand("", "Method")
// cmd := ctx.FindMethod("", "Method")
// // Find a command from a subcommand:
// cmd = ctx.FindCommand("Starboard", "Reset")
// cmd = ctx.FindMethod("Starboard", "Reset")
//
func (ctx *Context) FindCommand(structname, methodname string) *CommandContext {
if structname == "" {
for _, c := range ctx.Commands {
if c.MethodName == methodname {
return c
}
}
return nil
func (ctx *Context) FindCommand(structName, methodName string) *MethodContext {
if structName == "" {
return ctx.Subcommand.FindCommand(methodName)
}
for _, sub := range ctx.subcommands {
if sub.StructName != structname {
continue
}
for _, c := range sub.Commands {
if c.MethodName == methodname {
return c
}
if sub.StructName == structName {
return sub.FindCommand(methodName)
}
}
return nil
}
@ -257,34 +266,48 @@ func (ctx *Context) FindCommand(structname, methodname string) *CommandContext {
// fails. This is recommended, as subcommands won't change after initializing
// once in runtime, thus fairly harmless after development.
func (ctx *Context) MustRegisterSubcommand(cmd interface{}) *Subcommand {
s, err := ctx.RegisterSubcommand(cmd)
return ctx.MustRegisterSubcommandCustom(cmd, "")
}
// MustRegisterSubcommandCustom works similarly to MustRegisterSubcommand, but
// takes an extra argument for a command name override.
func (ctx *Context) MustRegisterSubcommandCustom(cmd interface{}, name string) *Subcommand {
s, err := ctx.RegisterSubcommandCustom(cmd, name)
if err != nil {
panic(err)
}
return s
}
// RegisterSubcommand registers and adds cmd to the list of subcommands. It will
// also return the resulting Subcommand.
func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) {
return ctx.RegisterSubcommandCustom(cmd, "")
}
// RegisterSubcommand registers and adds cmd to the list of subcommands with a
// custom command name (optional).
func (ctx *Context) RegisterSubcommandCustom(cmd interface{}, name string) (*Subcommand, error) {
s, err := NewSubcommand(cmd)
if err != nil {
return nil, errors.Wrap(err, "Failed to add subcommand")
return nil, errors.Wrap(err, "failed to add subcommand")
}
// Register the subcommand's name.
s.NeedsName()
if name != "" {
s.Command = name
}
if err := s.InitCommands(ctx); err != nil {
return nil, errors.Wrap(err, "Failed to initialize subcommand")
return nil, errors.Wrap(err, "failed to initialize subcommand")
}
// Do a collision check
for _, sub := range ctx.subcommands {
if sub.Command == s.Command {
return nil, errors.New(
"New subcommand has duplicate name: " + s.Command)
return nil, errors.New("new subcommand has duplicate name: " + s.Command)
}
}
@ -292,8 +315,8 @@ func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) {
return s, nil
}
// Start adds itself into the discordgo Session handlers. This needs to be run.
// The returned function is a delete function, which removes itself from the
// Start adds itself into the session handlers. This needs to be run. The
// returned function is a delete function, which removes itself from the
// Session handlers.
func (ctx *Context) Start() func() {
return ctx.State.AddHandler(func(v interface{}) {
@ -317,7 +340,7 @@ func (ctx *Context) Start() func() {
case *ErrInvalidUsage, *ErrUnknownCommand:
// Ignore
default:
ctx.ErrorLogger(errors.Wrap(err, "Command error"))
ctx.ErrorLogger(errors.Wrap(err, "command error"))
}
return
@ -349,63 +372,75 @@ func (ctx *Context) Call(event interface{}) error {
return ctx.callCmd(event)
}
// Help generates one. This function is used more for reference than an actual
// help message. As such, it only uses exported fields or methods.
// Help generates a full Help message. It serves mainly as a reference for
// people to reimplement and change. It doesn't show hidden commands.
func (ctx *Context) Help() string {
return ctx.help(true)
return ctx.HelpGenerate(false)
}
func (ctx *Context) HelpAdmin() string {
return ctx.help(false)
}
func (ctx *Context) help(hideAdmin bool) string {
const indent = " "
var help strings.Builder
// Generate the headers and descriptions
help.WriteString("__Help__")
// HelpGenerate generates a full Help message. It serves mainly as a reference
// for people to reimplement and change. If showHidden is true, then hidden
// subcommands and commands will be shown.
func (ctx *Context) HelpGenerate(showHidden bool) string {
// Generate the header.
buf := strings.Builder{}
buf.WriteString("__Help__")
// Name an
if ctx.Name != "" {
help.WriteString(": " + ctx.Name)
buf.WriteString(": " + ctx.Name)
}
if ctx.Description != "" {
help.WriteString("\n" + indent + ctx.Description)
}
if ctx.Flag.Is(AdminOnly) {
// That's it.
return help.String()
buf.WriteString("\n" + IndentLines(ctx.Description))
}
// Separators
help.WriteString("\n---\n")
buf.WriteString("\n---\n")
// Generate all commands
help.WriteString("__Commands__")
help.WriteString(ctx.Subcommand.Help(indent, hideAdmin))
help.WriteByte('\n')
if help := ctx.Subcommand.Help(); help != "" {
buf.WriteString("__Commands__\n")
buf.WriteString(IndentLines(help))
buf.WriteByte('\n')
}
var subHelp = strings.Builder{}
var subcommands = ctx.Subcommands()
var subhelps = make([]string, 0, len(subcommands))
for _, sub := range subcommands {
if help := sub.Help(indent, hideAdmin); help != "" {
for _, line := range strings.Split(help, "\n") {
subHelp.WriteString(indent)
subHelp.WriteString(line)
subHelp.WriteByte('\n')
}
if sub.Hidden && !showHidden {
continue
}
help := sub.HelpShowHidden(showHidden)
if help == "" {
continue
}
help = IndentLines(help)
var header = "**" + sub.Command + "**"
if sub.Description != "" {
header += ": " + sub.Description
}
subhelps = append(subhelps, header+"\n"+help)
}
if subHelp.Len() > 0 {
help.WriteString("---\n")
help.WriteString("__Subcommands__\n")
help.WriteString(subHelp.String())
if len(subhelps) > 0 {
buf.WriteString("---\n")
buf.WriteString("__Subcommands__\n")
buf.WriteString(IndentLines(strings.Join(subhelps, "\n")))
}
return help.String()
return buf.String()
}
// IndentLine prefixes every line from input with a single-level indentation.
func IndentLines(input string) string {
const indent = " "
var lines = strings.Split(input, "\n")
for i := range lines {
lines[i] = indent + lines[i]
}
return strings.Join(lines, "\n")
}

View file

@ -5,136 +5,107 @@ import (
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
)
// NonFatal is an interface that a method can implement to ignore all errors.
// This works similarly to Break.
type NonFatal interface {
error
IgnoreError() // noop method
}
func onlyFatal(err error) error {
if _, ok := err.(NonFatal); ok {
return nil
}
return err
}
type _Break struct{ error }
// implement NonFatal.
func (_Break) IgnoreError() {}
// Break is a non-fatal error that could be returned from middlewares or
// handlers to stop the chain of execution.
//
// Middlewares are guaranteed to be executed before handlers, but the exact
// order of each are undefined. Main handlers are also guaranteed to be executed
// before all subcommands. If a main middleware cancels, no subcommand
// middlewares will be called.
//
// Break implements the NonFatal interface, which causes an error to be ignored.
var Break NonFatal = _Break{errors.New("break middleware chain, non-fatal")}
func (ctx *Context) filterEventType(evT reflect.Type) []*CommandContext {
var callers []*CommandContext
var middles []*CommandContext
var found bool
find := func(sub *Subcommand) {
for _, cmd := range sub.Events {
// Search only for callers, so skip middlewares.
if cmd.Flag.Is(Middleware) {
continue
}
if cmd.event == evT {
callers = append(callers, cmd)
found = true
}
}
// Only get middlewares if we found handlers for that same event.
if found {
// Search for middlewares with the same type:
for _, mw := range sub.mwMethods {
if mw.event == evT {
middles = append(middles, mw)
}
}
}
}
// Break is a non-fatal error that could be returned from middlewares to stop
// the chain of execution.
var Break = errors.New("break middleware chain, non-fatal")
// filterEventType filters all commands and subcommands into a 2D slice,
// structured so that a Break would only exit out the nested slice.
func (ctx *Context) filterEventType(evT reflect.Type) (callers [][]caller) {
// Find the main context first.
find(ctx.Subcommand)
callers = append(callers, ctx.eventCallers(evT))
for _, sub := range ctx.subcommands {
// Reset found status
found = false
// Find subcommands second.
find(sub)
callers = append(callers, sub.eventCallers(evT))
}
return append(middles, callers...)
return
}
func (ctx *Context) callCmd(ev interface{}) error {
evT := reflect.TypeOf(ev)
func (ctx *Context) callCmd(ev interface{}) (bottomError error) {
evV := reflect.ValueOf(ev)
evT := evV.Type()
var isAdmin *bool // I want to die.
var isGuild *bool
var callers []*CommandContext
var callers [][]caller
// Hit the cache
t, ok := ctx.typeCache.Load(evT)
if ok {
callers = t.([]*CommandContext)
callers = t.([][]caller)
} else {
callers = ctx.filterEventType(evT)
ctx.typeCache.Store(evT, callers)
}
// We can't do the callers[:0] trick here, as it will modify the slice
// inside the sync.Map as well.
var filtered = make([]*CommandContext, 0, len(callers))
for _, subcallers := range callers {
for _, c := range subcallers {
_, err := c.call(evV)
if err != nil {
// Only count as an error if it's not Break.
if err = errNoBreak(err); err != nil {
bottomError = err
}
for _, cmd := range callers {
// Command flags will inherit its parent Subcommand's flags.
if true &&
!(cmd.Flag.Is(AdminOnly) && !ctx.eventIsAdmin(ev, &isAdmin)) &&
!(cmd.Flag.Is(GuildOnly) && !ctx.eventIsGuild(ev, &isGuild)) {
filtered = append(filtered, cmd)
}
}
for _, c := range filtered {
_, err := callWith(c.value, ev)
if err != nil {
if err = onlyFatal(err); err != nil {
ctx.ErrorLogger(err)
// Break the caller loop only for this subcommand.
break
}
return err
}
}
// We call the messages later, since Hidden handlers will go into the Events
// slice, but we don't want to ignore those handlers either.
if evT == typeMessageCreate {
// safe assertion always
err := ctx.callMessageCreate(ev.(*gateway.MessageCreateEvent))
return onlyFatal(err)
var msc *gateway.MessageCreateEvent
// We call the messages later, since we want MessageCreate middlewares to
// run as well.
switch {
case evT == typeMessageCreate:
msc = ev.(*gateway.MessageCreateEvent)
case evT == typeMessageUpdate && ctx.EditableCommands:
up := ev.(*gateway.MessageUpdateEvent)
// Message updates could have empty contents when only their embeds are
// filled. We don't need that here.
if up.Content == "" {
return nil
}
// Query the updated message.
m, err := ctx.Store.Message(up.ChannelID, up.ID)
if err != nil {
// It's probably safe to ignore this.
return nil
}
// Treat the message update as a message create event to avoid breaking
// changes.
msc = &gateway.MessageCreateEvent{Message: *m, Member: up.Member}
// Fill up member, if available.
if m.GuildID.IsValid() && up.Member == nil {
if mem, err := ctx.Store.Member(m.GuildID, m.Author.ID); err == nil {
msc.Member = mem
}
}
// Update the reflect value as well.
evV = reflect.ValueOf(msc)
default:
// Unknown event, return.
return nil
}
return nil
// There's no need for an errNoBreak here, as the method already checked
// for that.
return ctx.callMessageCreate(msc, evV)
}
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent, value reflect.Value) error {
// check if bot
if !ctx.AllowBot && mc.Author.Bot {
return nil
@ -146,7 +117,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
return nil
}
// trim the prefix before splitting, this way multi-words prefices work
// trim the prefix before splitting, this way multi-words prefixes work
content := mc.Content[len(pf):]
if content == "" {
@ -154,111 +125,26 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
}
// parse arguments
parts, err := ctx.ParseArgs(content)
if err != nil {
return errors.Wrap(err, "Failed to parse command")
}
parts, parseErr := ctx.ParseArgs(content)
// We're not checking parse errors yet, as raw arguments may be able to
// ignore it.
if len(parts) == 0 {
return nil // ???
return parseErr
}
var cmd *CommandContext
var sub *Subcommand
// var start int // arg starts from $start
// Check if plumb:
if ctx.plumb {
cmd = ctx.Commands[0]
sub = ctx.Subcommand
// start = 0
// Find the command and subcommand.
arguments, cmd, sub, err := ctx.findCommand(parts)
if err != nil {
return errNoBreak(err)
}
// Arguments slice, which will be sliced away until only arguments are left.
var arguments = parts
// We don't run the subcommand's middlewares here, as the callCmd function
// already handles that.
// If not plumb, search for the command
if cmd == nil {
for _, c := range ctx.Commands {
if c.Command == parts[0] {
cmd = c
sub = ctx.Subcommand
arguments = arguments[1:]
// start = 1
break
}
}
}
// Can't find the command, look for subcommands if len(args) has a 2nd
// entry.
if cmd == nil {
for _, s := range ctx.subcommands {
if s.Command != parts[0] {
continue
}
// Check if plumb:
if s.plumb {
cmd = s.Commands[0]
sub = s
arguments = arguments[1:]
// start = 1
break
}
// There's no second argument, so we can only look for Plumbed
// subcommands.
if len(parts) < 2 {
continue
}
for _, c := range s.Commands {
if c.Command == parts[1] {
cmd = c
sub = s
arguments = arguments[2:]
break
// start = 2
}
}
if cmd == nil {
if s.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[1],
Parent: parts[0],
ctx: s.Commands,
}
}
break
}
}
if cmd == nil {
if ctx.QuietUnknownCommand {
return nil
}
return &ErrUnknownCommand{
Command: parts[0],
ctx: ctx.Commands,
}
}
// Check for IsAdmin and IsGuild
if cmd.Flag.Is(GuildOnly) && !mc.GuildID.Valid() {
return nil
}
if cmd.Flag.Is(AdminOnly) {
p, err := ctx.State.Permissions(mc.ChannelID, mc.Author.ID)
if err != nil || !p.Has(discord.PermissionAdministrator) {
return nil
}
// Run command middlewares.
if err := cmd.walkMiddlewares(value); err != nil {
return errNoBreak(err)
}
// Start converting
@ -270,7 +156,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
// Here's an edge case: when the handler takes no arguments, we allow that
// anyway, as they might've used the raw content.
if len(cmd.Arguments) < 1 {
if len(cmd.Arguments) == 0 {
goto Call
}
@ -339,7 +225,7 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
vars := make([]reflect.Value, 0, len(arguments))
// Parse the rest with variadic arguments. Go's reflect states that
// varidic parameters will automatically be copied, which is good.
// variadic parameters will automatically be copied, which is good.
for i := 0; len(arguments) > 0; i++ {
v, err := last.fn(arguments[0])
if err != nil {
@ -371,19 +257,31 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
// If the argument wants all arguments in string:
case last.custom != nil:
// Ignore parser errors. This allows custom commands sliced away to
// have erroneous hanging quotes.
parseErr = nil
// Manual string seeking is a must here. This is because the string
// could contain multiple whitespaces, and the parser would not
// count them.
var seekTo = cmd.Command
// If plumbed, then there would only be the subcommand.
if sub.plumb {
// We can't rely on the plumbing behavior.
if sub.plumbed != nil {
seekTo = sub.Command
}
// Seek to the string.
if i := strings.Index(content, seekTo); i > -1 {
var i = strings.Index(content, seekTo)
// Edge case if the subcommand is the same as the command.
if cmd.Command == sub.Command {
// Seek again past the command.
i = strings.Index(content[i+len(seekTo):], seekTo)
}
if i > -1 {
// Seek past the substring.
i += len(seekTo)
content = strings.TrimSpace(content[i:])
}
@ -405,18 +303,14 @@ func (ctx *Context) callMessageCreate(mc *gateway.MessageCreateEvent) error {
argv = append(argv, v)
}
Call:
// Try calling all middlewares first. We don't need to stack middlewares, as
// there will only be one command match.
for _, mw := range sub.mwMethods {
_, err := callWith(mw.value, mc)
if err != nil {
return err
}
// Check for parsing errors after parsing arguments.
if parseErr != nil {
return parseErr
}
Call:
// call the function and parse the error return value
v, err := callWith(cmd.value, mc, argv...)
v, err := cmd.call(value, argv...)
if err != nil {
return err
}
@ -437,91 +331,73 @@ Call:
return err
}
func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
if *is != nil {
return **is
// findCommand filters.
func (ctx *Context) findCommand(parts []string) ([]string, *MethodContext, *Subcommand, error) {
// Main command entrypoint cannot have plumb.
for _, c := range ctx.Commands {
if c.Command == parts[0] {
return parts[1:], c, ctx.Subcommand, nil
}
// Check for alias
for _, alias := range c.Aliases {
if alias == parts[0] {
return parts[1:], c, ctx.Subcommand, nil
}
}
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
var userID = infer.UserID(ev)
if !userID.Valid() {
return false
}
var res bool
p, err := ctx.State.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
res = true
}
*is = &res
return res
}
func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
if *is != nil {
return **is
}
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
c, err := ctx.State.Channel(channelID)
if err != nil {
return false
}
res := c.GuildID.Valid()
*is = &res
return res
}
func callWith(
caller reflect.Value,
ev interface{}, values ...reflect.Value) (interface{}, error) {
var callargs = make([]reflect.Value, 0, 1+len(values))
if v, ok := ev.(reflect.Value); ok {
callargs = append(callargs, v)
} else {
callargs = append(callargs, reflect.ValueOf(ev))
}
callargs = append(callargs, values...)
return errorReturns(caller.Call(callargs))
}
func errorReturns(returns []reflect.Value) (interface{}, error) {
// Handlers may return nothing.
if len(returns) == 0 {
return nil, nil
}
// assume first return is always error, since we checked for this in
// parseCommands.
v := returns[len(returns)-1].Interface()
// If the last return (error) is nil.
if v == nil {
// If we only have 1 returns, that return must be the error. The error
// is nil, so nil is returned.
if len(returns) == 1 {
return nil, nil
// Can't find the command, look for subcommands if len(args) has a 2nd
// entry.
for _, s := range ctx.subcommands {
if s.Command != parts[0] {
continue
}
// Return the first argument as-is. The above returns[-1] check assumes
// 2 return values (T, error), meaning returns[0] is the T value.
return returns[0].Interface(), nil
// Only actually plumb if we actually have a plumbed handler AND
// 1. We only have one command handler OR
// 2. We only have the subcommand name but no command.
if s.plumbed != nil && (len(s.Commands) == 1 || len(parts) <= 2) {
return parts[1:], s.plumbed, s, nil
}
if len(parts) >= 2 {
for _, c := range s.Commands {
if c.Command == parts[1] {
return parts[2:], c, s, nil
}
// Check for aliases
for _, alias := range c.Aliases {
if alias == parts[1] {
return parts[2:], c, s, nil
}
}
}
}
// If unknown command is disabled or the subcommand is hidden:
if ctx.SilentUnknown.Subcommand || s.Hidden {
return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: s,
}
}
// Treat the last return as an error.
return nil, v.(error)
if ctx.SilentUnknown.Command {
return nil, nil, nil, Break
}
return nil, nil, nil, &ErrUnknownCommand{
Parts: parts,
Subcmd: ctx.Subcommand,
}
}
func errNoBreak(err error) error {
if errors.Is(err, Break) {
return nil
}
return err
}

View file

@ -15,22 +15,26 @@ type hasPlumb struct {
NotPlumbed bool
}
func (h *hasPlumb) Setup(sub *Subcommand) {
sub.SetPlumb("Plumber")
}
func (h *hasPlumb) Normal(_ *gateway.MessageCreateEvent) error {
h.NotPlumbed = true
return nil
}
func (h *hasPlumb) PーPlumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
func (h *hasPlumb) Plumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
h.Plumbed = string(c)
return nil
}
func TestSubcommandPlumb(t *testing.T) {
var state = &state.State{
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
c, err := New(state, &testc{})
c, err := New(s, &testc{})
if err != nil {
t.Fatal("Failed to create new context:", err)
}
@ -43,14 +47,10 @@ func TestSubcommandPlumb(t *testing.T) {
t.Fatal("Failed to register hasPlumb:", err)
}
if l := len(c.subcommands[0].Commands); l != 1 {
t.Fatal("Unexpected length for sub.Commands:", l)
}
// Try call exactly what's in the Plumb example:
m := &gateway.MessageCreateEvent{
Message: discord.Message{
Content: "hasPlumb test command",
Content: "hasPlumb",
},
}
@ -61,6 +61,50 @@ func TestSubcommandPlumb(t *testing.T) {
if p.NotPlumbed {
t.Fatal("Normal method called for hasPlumb")
}
}
type onlyPlumb struct {
Ctx *Context
Plumbed string
}
func (h *onlyPlumb) Setup(sub *Subcommand) {
sub.SetPlumb("Plumber")
}
func (h *onlyPlumb) Plumber(_ *gateway.MessageCreateEvent, c RawArguments) error {
h.Plumbed = string(c)
return nil
}
func TestSubcommandOnlyPlumb(t *testing.T) {
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
c, err := New(s, &testc{})
if err != nil {
t.Fatal("Failed to create new context:", err)
}
c.HasPrefix = NewPrefix("")
p := &onlyPlumb{}
_, err = c.RegisterSubcommand(p)
if err != nil {
t.Fatal("Failed to register hasPlumb:", err)
}
// Try call exactly what's in the Plumb example:
m := &gateway.MessageCreateEvent{
Message: discord.Message{
Content: "onlyPlumb test command",
},
}
if err := c.callCmd(m); err != nil {
t.Fatal("Failed to call message:", err)
}
if p.Plumbed != "test command" {
t.Fatal("Unexpected custom argument for plumbed:", p.Plumbed)

View file

@ -12,61 +12,63 @@ import (
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
"github.com/diamondburned/arikawa/utils/handler"
)
type testc struct {
Ctx *Context
Return chan interface{}
Counter uint64
Typed bool
Typed int8
}
func (t *testc) MーBumpCounter(interface{}) {
t.Counter++
func (t *testc) Setup(sub *Subcommand) {
sub.AddMiddleware("*,GetCounter", func(v interface{}) {
t.Counter++
})
sub.AddMiddleware("*", func(*gateway.MessageCreateEvent) {
t.Counter++
})
// stub middleware for testing
sub.AddMiddleware("OnTyping", func(*gateway.TypingStartEvent) {
t.Typed = 2
})
sub.Hide("Hidden")
}
func (t *testc) GetCounter(_ *gateway.MessageCreateEvent) {
func (t *testc) Hidden(*gateway.MessageCreateEvent) {}
func (t *testc) Noop(*gateway.MessageCreateEvent) {}
func (t *testc) GetCounter(*gateway.MessageCreateEvent) {
t.Return <- strconv.FormatUint(t.Counter, 10)
}
func (t *testc) Send(_ *gateway.MessageCreateEvent, args ...string) error {
t.Return <- args
return errors.New("oh no")
}
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *customManualParsed) {
t.Return <- c.args
func (t *testc) Custom(_ *gateway.MessageCreateEvent, c *ArgumentParts) {
t.Return <- []string(*c)
}
func (t *testc) Variadic(_ *gateway.MessageCreateEvent, c ...*customParsed) {
t.Return <- c[len(c)-1]
}
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, s string, c *customManualParsed) {
t.Return <- c.args
func (t *testc) TrailCustom(_ *gateway.MessageCreateEvent, _ string, c ArgumentParts) {
t.Return <- c
}
func (t *testc) Content(_ *gateway.MessageCreateEvent, c RawArguments) {
t.Return <- c
}
func (t *testc) NoArgs(_ *gateway.MessageCreateEvent) error {
func (t *testc) NoArgs(*gateway.MessageCreateEvent) error {
return errors.New("passed")
}
func (t *testc) Noop(_ *gateway.MessageCreateEvent) {
}
func (t *testc) OnTyping(_ *gateway.TypingStartEvent) {
t.Typed = true
func (t *testc) OnTyping(*gateway.TypingStartEvent) {
t.Typed--
}
func TestNewContext(t *testing.T) {
var state = &state.State{
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
c, err := New(state, &testc{})
c, err := New(s, &testc{})
if err != nil {
t.Fatal("Failed to create new context:", err)
}
@ -78,18 +80,22 @@ func TestNewContext(t *testing.T) {
func TestContext(t *testing.T) {
var given = &testc{}
var state = &state.State{
Store: state.NewDefaultStore(nil),
var s = &state.State{
Store: state.NewDefaultStore(nil),
Handler: handler.New(),
}
s, err := NewSubcommand(given)
sub, err := NewSubcommand(given)
if err != nil {
t.Fatal("Failed to create subcommand:", err)
}
var ctx = &Context{
Subcommand: s,
State: state,
Name: "arikawa/bot test",
Description: "Just a test.",
Subcommand: sub,
State: s,
ParseArgs: DefaultArgsParser(),
}
@ -99,11 +105,11 @@ func TestContext(t *testing.T) {
}
if given.Ctx == nil {
t.Fatal("given's Context field is nil")
t.Fatal("given'sub Context field is nil")
}
if given.Ctx.State.Store == nil {
t.Fatal("given's State is nil")
t.Fatal("given'sub State is nil")
}
})
@ -115,11 +121,22 @@ func TestContext(t *testing.T) {
})
t.Run("help", func(t *testing.T) {
if h := ctx.Help(); h == "" {
ctx.MustRegisterSubcommandCustom(&testc{}, "helper")
h := ctx.Help()
if h == "" {
t.Fatal("Empty help?")
}
if h := ctx.HelpAdmin(); h == "" {
t.Fatal("Empty admin help?")
if strings.Contains(h, "hidden") {
t.Fatal("Hidden command shown in help.")
}
if !strings.Contains(h, "arikawa/bot test") {
t.Fatal("Name not found.")
}
if !strings.Contains(h, "Just a test.") {
t.Fatal("Description not found.")
}
})
@ -127,7 +144,7 @@ func TestContext(t *testing.T) {
ctx.HasPrefix = NewPrefix("pls do ")
// This should trigger the middleware first.
if err := expect(ctx, given, "1", "pls do getCounter"); err != nil {
if err := expect(ctx, given, "3", "pls do getCounter"); err != nil {
t.Fatal("Unexpected error:", err)
}
})
@ -139,7 +156,8 @@ func TestContext(t *testing.T) {
t.Fatal("Failed to call with TypingStart:", err)
}
if !given.Typed {
// -1 none ran
if given.Typed != 1 {
t.Fatal("Typed bool is false")
}
})
@ -149,11 +167,11 @@ func TestContext(t *testing.T) {
ctx.HasPrefix = NewPrefix("~")
var (
strings = "hacka doll no. 3"
send = "hacka doll no. 3"
expects = []string{"hacka", "doll", "no.", "3"}
)
if err := expect(ctx, given, expects, "~send "+strings); err.Error() != "oh no" {
if err := expect(ctx, given, expects, "~send "+send); err.Error() != "oh no" {
t.Fatal("Unexpected error:", err)
}
})
@ -187,11 +205,27 @@ func TestContext(t *testing.T) {
t.Run("call command custom trailing manual parser", func(t *testing.T) {
ctx.HasPrefix = NewPrefix("!")
expects := []string{}
expects := ArgumentParts{"arikawa"}
if err := expect(ctx, given, expects, "!trailCustom hime_arikawa"); err != nil {
if err := sendMsg(ctx, given, &expects, "!trailCustom hime arikawa"); err != nil {
t.Fatal("Unexpected call error:", err)
}
if expects.Length() != 1 {
t.Fatal("Unexpected ArgumentParts length.")
}
if expects.After(1)+expects.After(2)+expects.After(-1) != "" {
t.Fatal("Unexpected ArgumentsParts after.")
}
if expects.String() != "arikawa" {
t.Fatal("Unexpected ArgumentsParts string.")
}
if expects.Arg(0) != "arikawa" {
t.Fatal("Unexpected ArgumentParts arg 0")
}
if expects.Arg(1) != "" {
t.Fatal("Unexpected ArgumentParts arg 1")
}
})
testMessage := func(content string) error {
@ -220,7 +254,7 @@ func TestContext(t *testing.T) {
err := testMessage("joe pls no")
if err == nil || !strings.HasPrefix(err.Error(), "Unknown command:") {
if err == nil || !strings.HasPrefix(err.Error(), "unknown command:") {
t.Fatal("unexpected error:", err)
}
})
@ -231,11 +265,7 @@ func TestContext(t *testing.T) {
ctx.HasPrefix = NewPrefix("run ")
sub := &testc{}
_, err := ctx.RegisterSubcommand(sub)
if err != nil {
t.Fatal("Failed to register subcommand:", err)
}
ctx.MustRegisterSubcommand(sub)
if err := testMessage("run testc noop"); err != nil {
t.Fatal("Unexpected error:", err)
@ -251,9 +281,49 @@ func TestContext(t *testing.T) {
t.Fatal("Failed to find subcommand Noop")
}
})
t.Run("register subcommand custom", func(t *testing.T) {
ctx.MustRegisterSubcommandCustom(&testc{}, "arikawa")
})
t.Run("duplicate subcommand", func(t *testing.T) {
_, err := ctx.RegisterSubcommandCustom(&testc{}, "arikawa")
if err := err.Error(); !strings.Contains(err, "duplicate") {
t.Fatal("Unexpected error:", err)
}
})
t.Run("start", func(t *testing.T) {
cancel := ctx.Start()
defer cancel()
ctx.HasPrefix = NewPrefix("!")
given.Return = make(chan interface{})
ctx.Handler.Call(&gateway.MessageCreateEvent{
Message: discord.Message{
Content: "!content hime arikawa best trap",
},
})
if c := (<-given.Return).(RawArguments); c != "hime arikawa best trap" {
t.Fatal("Unexpected content:", c)
}
})
}
func expect(ctx *Context, given *testc, expects interface{}, content string) (call error) {
var v interface{}
if call = sendMsg(ctx, given, &v, content); call != nil {
return
}
if !reflect.DeepEqual(v, expects) {
return fmt.Errorf("returned argument is invalid: %v", v)
}
return nil
}
func sendMsg(ctx *Context, given *testc, into interface{}, content string) (call error) {
// Return channel for testing
ret := make(chan interface{})
given.Return = ret
@ -267,47 +337,46 @@ func expect(ctx *Context, given *testc, expects interface{}, content string) (ca
var callCh = make(chan error)
go func() {
callCh <- ctx.callCmd(m)
callCh <- ctx.Call(m)
}()
select {
case arg := <-ret:
if !reflect.DeepEqual(arg, expects) {
return fmt.Errorf("returned argument is invalid: %v", arg)
}
call = <-callCh
reflect.ValueOf(into).Elem().Set(reflect.ValueOf(arg))
return
case call = <-callCh:
return fmt.Errorf("expected return before error: %w", call)
case <-time.After(time.Second):
return errors.New("Timed out while waiting")
return errors.New("timed out while waiting")
}
}
func BenchmarkConstructor(b *testing.B) {
var state = &state.State{
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
for i := 0; i < b.N; i++ {
_, _ = New(state, &testc{})
_, _ = New(s, &testc{})
}
}
func BenchmarkCall(b *testing.B) {
var given = &testc{}
var state = &state.State{
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
s, _ := NewSubcommand(given)
sub, _ := NewSubcommand(given)
var ctx = &Context{
Subcommand: s,
State: state,
Subcommand: sub,
State: s,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
}
m := &gateway.MessageCreateEvent{
@ -325,16 +394,17 @@ func BenchmarkCall(b *testing.B) {
func BenchmarkHelp(b *testing.B) {
var given = &testc{}
var state = &state.State{
var s = &state.State{
Store: state.NewDefaultStore(nil),
}
s, _ := NewSubcommand(given)
sub, _ := NewSubcommand(given)
var ctx = &Context{
Subcommand: s,
State: state,
Subcommand: sub,
State: s,
HasPrefix: NewPrefix("~"),
ParseArgs: DefaultArgsParser(),
}
b.ResetTimer()

View file

@ -6,33 +6,24 @@ import (
)
type ErrUnknownCommand struct {
Prefix string
Command string
Parent string
// TODO: list available commands?
// Here, as a reminder
ctx []*CommandContext
Parts []string // max len 2
Subcmd *Subcommand
}
func (err *ErrUnknownCommand) Error() string {
if len(err.Parts) > 2 {
err.Parts = err.Parts[:2]
}
return UnknownCommandString(err)
}
var UnknownCommandString = func(err *ErrUnknownCommand) string {
var header = "Unknown command: " + err.Prefix
if err.Parent != "" {
header += err.Parent + " " + err.Command
} else {
header += err.Command
}
return header
return "unknown command: " + strings.Join(err.Parts, " ")
}
var (
ErrTooManyArgs = errors.New("Too many arguments given")
ErrNotEnoughArgs = errors.New("Not enough arguments given")
ErrTooManyArgs = errors.New("too many arguments given")
ErrNotEnoughArgs = errors.New("not enough arguments given")
)
type ErrInvalidUsage struct {
@ -43,7 +34,7 @@ type ErrInvalidUsage struct {
// TODO: usage generator?
// Here, as a reminder
Ctx *CommandContext
Ctx *MethodContext
}
func (err *ErrInvalidUsage) Error() string {
@ -55,12 +46,12 @@ func (err *ErrInvalidUsage) Unwrap() error {
}
var InvalidUsageString = func(err *ErrInvalidUsage) string {
if err.Index == 0 {
return "Invalid usage, error: " + err.Wrap.Error() + "."
if err.Index == 0 && err.Wrap != nil {
return "invalid usage, error: " + err.Wrap.Error() + "."
}
if len(err.Args) == 0 {
return "Missing arguments. Refer to help."
if err.Index == 0 || len(err.Args) == 0 {
return "missing arguments. Refer to help."
}
body := "Invalid usage at " +

56
bot/error_test.go Normal file
View file

@ -0,0 +1,56 @@
package bot
import (
"errors"
"strings"
"testing"
)
func TestInvalidUsage(t *testing.T) {
t.Run("fmt", func(t *testing.T) {
err := ErrInvalidUsage{
Prefix: "!",
Args: []string{"hime", "arikawa"},
Index: 1,
Wrap: errors.New("test error"),
}
str := err.Error()
if !strings.Contains(str, "test error") {
t.Fatal("does not contain 'test error':", str)
}
if !strings.Contains(str, "__arikawa__") {
t.Fatal("Unexpected highlight index:", str)
}
})
t.Run("missing arguments", func(t *testing.T) {
err := ErrInvalidUsage{}
str := err.Error()
if str != "missing arguments. Refer to help." {
t.Fatal("Unexpected error:", str)
}
})
t.Run("no index", func(t *testing.T) {
err := ErrInvalidUsage{Wrap: errors.New("astolfo")}
str := err.Error()
if str != "invalid usage, error: astolfo." {
t.Fatal("Unexpected error:", str)
}
})
t.Run("unwrap", func(t *testing.T) {
var err = errors.New("hackadoll no. 3")
var wrap = &ErrInvalidUsage{
Wrap: err,
}
if !errors.Is(wrap, err) {
t.Fatal("Failed to unwrap, errors mismatch.")
}
})
}

View file

@ -11,11 +11,11 @@ import (
var (
EmojiRegex = regexp.MustCompile(`<(a?):(.+?):(\d+)>`)
ErrInvalidEmoji = errors.New("Invalid emoji")
ErrInvalidEmoji = errors.New("invalid emoji")
)
type Emoji struct {
ID discord.Snowflake
ID discord.EmojiID
Name string
Custom bool
@ -83,7 +83,7 @@ func (e *Emoji) Parse(arg string) error {
e.Custom = true
e.Animated = matches[1] == "a"
e.Name = matches[2]
e.ID = id
e.ID = discord.EmojiID(id)
return nil
}

View file

@ -4,7 +4,6 @@ import (
"bytes"
"flag"
"io/ioutil"
"strings"
)
var FlagName = "command"
@ -30,41 +29,21 @@ func (fs *FlagSet) Usage() string {
return buf.String()
}
type Flag struct {
command string
arguments []string
}
type Flag []string
func (f *Flag) ParseContent(arguments []string) error {
// trim the command out
f.command, f.arguments = arguments[0], arguments[1:]
*f = arguments
return nil
}
func (f *Flag) Usage() string {
return "[flags] arguments..."
func (f Flag) Usage() string {
return "[flags] arguments"
}
func (f *Flag) Command() string {
return f.command
func (f Flag) Args() []string {
return f
}
func (f *Flag) Args() []string {
return f.arguments
}
func (f *Flag) Arg(n int) string {
if n < 0 || n >= len(f.arguments) {
return ""
}
return f.arguments[n]
}
func (f *Flag) String() string {
return strings.Join(f.arguments, " ")
}
func (f *Flag) With(fs *flag.FlagSet) error {
return fs.Parse(f.arguments)
func (f Flag) With(fs *flag.FlagSet) error {
return fs.Parse(f)
}

View file

@ -28,30 +28,14 @@ func TestFlagSet(t *testing.T) {
func TestFlag(t *testing.T) {
f := Flag{}
if err := f.ParseContent([]string{"gc", "--now", "1m4s"}); err != nil {
if err := f.ParseContent([]string{"--now", "1m4s"}); err != nil {
t.Fatal("Failed to parse:", err)
}
if f.Command() != "gc" {
t.Fatal("Unexpected command:", f.Command())
}
if args := f.Args(); !reflect.DeepEqual(args, []string{"--now", "1m4s"}) {
t.Fatal("Unexpected arguments:", args)
}
if arg := f.Arg(1200); arg != "" {
t.Fatal("Unexpected argument at 1200th:", arg)
}
if arg := f.Arg(0); arg != "--now" {
t.Fatal("Unexpected argument at 1st:", arg)
}
if s := f.String(); s != "--now 1m4s" {
t.Fatal("Unexpected string:", s)
}
fs := NewFlagSet()
var now bool

View file

@ -12,20 +12,20 @@ import (
// canary. matches canary MessageURL
// 3 `(\d+)` for guild ID, channel ID and message ID
var Regex = regexp.MustCompile(
`https://(|ptb\.|canary\.)discordapp\.com/channels/(\d+)/(\d+)/(\d+)`,
`https://(ptb\.|canary\.)?discord(?:app)?\.com/channels/(\d+)/(\d+)/(\d+)`,
)
// MessageURL contains info from a MessageURL
type MessageURL struct {
GuildID discord.Snowflake
ChannelID discord.Snowflake
MessageID discord.Snowflake
GuildID discord.GuildID
ChannelID discord.ChannelID
MessageID discord.MessageID
}
func (url *MessageURL) Parse(arg string) error {
u := ParseMessageURL(arg)
if u == nil {
return errors.New("Invalid MessageURL format.")
return errors.New("invalid MessageURL format")
}
*url = *u
return nil
@ -55,8 +55,8 @@ func ParseMessageURL(url string) *MessageURL {
}
return &MessageURL{
GuildID: gID,
ChannelID: cID,
MessageID: mID,
GuildID: discord.GuildID(gID),
ChannelID: discord.ChannelID(cID),
MessageID: discord.MessageID(mID),
}
}

View file

@ -15,7 +15,7 @@ var (
//
type ChannelMention discord.Snowflake
type ChannelMention discord.ChannelID
func (m *ChannelMention) Parse(arg string) error {
return grabFirst(ChannelRegex, "channel mention", arg, (*discord.Snowflake)(m))
@ -25,8 +25,8 @@ func (m *ChannelMention) Usage() string {
return "#channel"
}
func (m *ChannelMention) ID() discord.Snowflake {
return discord.Snowflake(*m)
func (m *ChannelMention) ID() discord.ChannelID {
return discord.ChannelID(*m)
}
func (m *ChannelMention) Mention() string {
@ -35,7 +35,7 @@ func (m *ChannelMention) Mention() string {
//
type UserMention discord.Snowflake
type UserMention discord.UserID
func (m *UserMention) Parse(arg string) error {
return grabFirst(UserRegex, "user mention", arg, (*discord.Snowflake)(m))
@ -45,8 +45,8 @@ func (m *UserMention) Usage() string {
return "@user"
}
func (m *UserMention) ID() discord.Snowflake {
return discord.Snowflake(*m)
func (m *UserMention) ID() discord.UserID {
return discord.UserID(*m)
}
func (m *UserMention) Mention() string {
@ -55,7 +55,7 @@ func (m *UserMention) Mention() string {
//
type RoleMention discord.Snowflake
type RoleMention discord.RoleID
func (m *RoleMention) Parse(arg string) error {
return grabFirst(RoleRegex, "role mention", arg, (*discord.Snowflake)(m))
@ -65,8 +65,8 @@ func (m *RoleMention) Usage() string {
return "@role"
}
func (m *RoleMention) ID() discord.Snowflake {
return discord.Snowflake(*m)
func (m *RoleMention) ID() discord.RoleID {
return discord.RoleID(*m)
}
func (m *RoleMention) Mention() string {
@ -78,12 +78,12 @@ func (m *RoleMention) Mention() string {
func grabFirst(reg *regexp.Regexp, item, input string, output *discord.Snowflake) error {
matches := reg.FindStringSubmatch(input)
if len(matches) < 2 {
return errors.New("Invalid " + item)
return errors.New("invalid " + item)
}
id, err := discord.ParseSnowflake(matches[1])
if err != nil {
return errors.New("Invalid " + item)
return errors.New("invalid " + item)
}
*output = id

View file

@ -6,40 +6,56 @@ import (
"github.com/diamondburned/arikawa/discord"
)
func TestMention(t *testing.T) {
var (
c ChannelMention
u UserMention
r RoleMention
)
func TestChannelMention(t *testing.T) {
test := new(ChannelMention)
str := "<#123123>"
var id discord.ChannelID = 123123
type mention interface {
Parse(arg string) error
ID() discord.Snowflake
Mention() string
if err := test.Parse(str); err != nil {
t.Fatal("Expected", id, "error:", err)
}
var tests = []struct {
mention
str string
id discord.Snowflake
}{
{&c, "<#123123>", 123123},
{&r, "<@&23321>", 23321},
{&u, "<@123123>", 123123},
if actualID := test.ID(); actualID != id {
t.Fatal("Expected", id, "got", id)
}
for _, test := range tests {
if err := test.Parse(test.str); err != nil {
t.Fatal("Expected", test.id, "error:", err)
}
if id := test.ID(); id != test.id {
t.Fatal("Expected", test.id, "got", id)
}
if mention := test.Mention(); mention != test.str {
t.Fatal("Expected", test.str, "got", mention)
}
if mention := test.Mention(); mention != str {
t.Fatal("Expected", str, "got", mention)
}
}
func TestUserMention(t *testing.T) {
test := new(UserMention)
str := "<@123123>"
var id discord.UserID = 123123
if err := test.Parse(str); err != nil {
t.Fatal("Expected", id, "error:", err)
}
if actualID := test.ID(); actualID != id {
t.Fatal("Expected", id, "got", id)
}
if mention := test.Mention(); mention != str {
t.Fatal("Expected", str, "got", mention)
}
}
func TestRoleMention(t *testing.T) {
test := new(RoleMention)
str := "<@&123123>"
var id discord.RoleID = 123123
if err := test.Parse(str); err != nil {
t.Fatal("Expected", id, "error:", err)
}
if actualID := test.ID(); actualID != id {
t.Fatal("Expected", id, "got", id)
}
if mention := test.Mention(); mention != str {
t.Fatal("Expected", str, "got", mention)
}
}

View file

@ -13,19 +13,21 @@ import (
// ChannelID looks for fields with name ChannelID, Channel, or in some special
// cases, ID.
func ChannelID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "Channel")
func ChannelID(event interface{}) discord.ChannelID {
return discord.ChannelID(reflectID(reflect.ValueOf(event), "Channel"))
}
// GuildID looks for fields with name GuildID, Guild, or in some special cases,
// ID.
func GuildID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "Guild")
func GuildID(event interface{}) discord.GuildID {
return discord.GuildID(reflectID(reflect.ValueOf(event), "Guild"))
}
// UserID looks for fields with name UserID, User, or in some special cases, ID.
func UserID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "User")
func UserID(event interface{}) discord.UserID {
// This may have a very fatal bug of accidentally mistaking another User's
// ID. It also probably wouldn't work with things like RecipientID.
return discord.UserID(reflectID(reflect.ValueOf(event), "User"))
}
func reflectID(v reflect.Value, thing string) discord.Snowflake {
@ -62,21 +64,203 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
switch fType.Kind() {
case reflect.Struct:
if chID := reflectID(v.Field(i), thing); chID.Valid() {
if chID := reflectID(v.Field(i), thing); chID.IsValid() {
return chID
}
case reflect.Int64:
if field.Name == thing+"ID" {
// grab value real quick
return discord.Snowflake(v.Field(i).Int())
}
case reflect.Uint64:
switch {
case false,
// Contains works with "LastMessageID" and such.
strings.Contains(field.Name, thing+"ID"),
// Special case where the struct name has Channel in it.
field.Name == "ID" && strings.Contains(t.Name(), thing):
// Special case where the struct name has Channel in it
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
return discord.Snowflake(v.Field(i).Int())
return discord.Snowflake(v.Field(i).Uint())
}
}
}
return 0
}
/*
var reflectCache sync.Map
type cacheKey struct {
t reflect.Type
f string
}
func getID(v reflect.Value, thing string) discord.Snowflake {
if !v.IsValid() {
return 0
}
t := v.Type()
if t.Kind() == reflect.Ptr {
v = v.Elem()
// Recheck after dereferring
if !v.IsValid() {
return 0
}
t = v.Type()
}
if t.Kind() != reflect.Struct {
return 0
}
return reflectID(thing, v, t)
}
type reflector struct {
steps []step
thing string
thingID string
}
type step struct {
field int
ptr bool
rec []step
}
func reflectID(thing string, v reflect.Value, t reflect.Type) discord.Snowflake {
r := &reflector{thing: thing}
// copy original type
key := r.thing + t.String()
// check the cache
if instructions, ok := reflectCache.Load(key); ok {
if instructions == nil {
return 0
}
return applyInstructions(v, instructions.([]step))
}
r.thingID = r.thing + "ID"
r.steps = make([]step, 0, 1)
id := r._id(v, t)
if r.steps != nil {
reflectCache.Store(key, r.instructions())
}
return id
}
func applyInstructions(v reflect.Value, instructions []step) discord.Snowflake {
// Use a type here to detect recursion:
// var originalT = v.Type()
var laststep reflect.Value
log.Println(v.Type(), instructions)
for i, step := range instructions {
if !v.IsValid() {
return 0
}
if i > 0 && step.ptr {
v = v.Elem()
}
if !v.IsValid() {
// is this the bottom of the instructions?
if i == len(instructions)-1 && step.rec != nil {
for _, ins := range step.rec {
var value = laststep.Field(ins.field)
if ins.ptr {
value = value.Elem()
}
if id := applyInstructions(value, instructions); id.IsValid() {
return id
}
}
}
return 0
}
laststep = v
v = laststep.Field(step.field)
}
return discord.Snowflake(v.Int())
}
func (r *reflector) instructions() []step {
if len(r.steps) == 0 {
return nil
}
var instructions = make([]step, len(r.steps))
for i := 0; i < len(instructions); i++ {
instructions[i] = r.steps[len(r.steps)-i-1]
}
// instructions := r.steps
return instructions
}
func (r *reflector) step(s step) {
r.steps = append(r.steps, s)
}
func (r *reflector) _id(v reflect.Value, t reflect.Type) (chID discord.Snowflake) {
numFields := t.NumField()
var ptr bool
var ins = step{field: -1}
for i := 0; i < numFields; i++ {
field := t.Field(i)
fType := field.Type
value := v.Field(i)
ptr = false
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
value = value.Elem()
ptr = true
}
// does laststep have the same field type?
if fType == t {
ins.rec = append(ins.rec, step{field: i, ptr: ptr})
}
if !value.IsValid() {
continue
}
// If we've already found the field:
if ins.field > 0 {
continue
}
switch fType.Kind() {
case reflect.Struct:
if chID = r._id(value, fType); chID.IsValid() {
ins.field = i
ins.ptr = ptr
}
case reflect.Int64:
switch {
case false,
// Contains works with "LastMessageID" and such.
strings.Contains(field.Name, r.thingID),
// Special case where the struct name has Channel in it.
field.Name == "ID" && strings.Contains(t.Name(), r.thing):
ins.field = i
ins.ptr = ptr
chID = discord.Snowflake(value.Int())
}
}
}
// If we've found the field:
r.step(ins)
return
}
*/

View file

@ -7,7 +7,7 @@ import (
)
type hasID struct {
ChannelID discord.Snowflake
ChannelID discord.ChannelID
}
type embedsID struct {
@ -16,7 +16,7 @@ type embedsID struct {
}
type hasChannelInName struct {
ID discord.Snowflake
ID discord.ChannelID
}
func TestReflectChannelID(t *testing.T) {
@ -51,15 +51,15 @@ func TestReflectChannelID(t *testing.T) {
})
}
var id discord.Snowflake
func BenchmarkReflectChannelID_1Level(b *testing.B) {
var s = &hasID{
ChannelID: 69420,
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
if id := ChannelID(s); id != s.ChannelID {
b.Fatal("Unexpected ChannelID:", id)
}
}
}
@ -80,6 +80,8 @@ func BenchmarkReflectChannelID_5Level(b *testing.B) {
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
if id := ChannelID(s); id != 69420 {
b.Fatal("Unexpected ChannelID:", id)
}
}
}

View file

@ -0,0 +1,49 @@
package middlewares
import (
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
)
func AdminOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
var channelID = infer.ChannelID(ev)
if !channelID.IsValid() {
return bot.Break
}
var userID = infer.UserID(ev)
if !userID.IsValid() {
return bot.Break
}
p, err := ctx.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
return nil
}
return bot.Break
}
}
func GuildOnly(ctx *bot.Context) func(interface{}) error {
return func(ev interface{}) error {
// Try and infer the GuildID.
if guildID := infer.GuildID(ev); guildID.IsValid() {
return nil
}
var channelID = infer.ChannelID(ev)
if !channelID.IsValid() {
return bot.Break
}
c, err := ctx.Channel(channelID)
if err != nil || !c.GuildID.IsValid() {
return bot.Break
}
return nil
}
}

View file

@ -0,0 +1,194 @@
package middlewares
import (
"errors"
"testing"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
)
func TestAdminOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = AdminOnly(ctx)
t.Run("allow message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 1337,
Author: discord.User{ID: 69420},
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 2,
ChannelID: 1337,
Author: discord.User{ID: 1337},
},
}
expectBreak(t, middleware(msg))
var pin = &gateway.ChannelPinsUpdateEvent{
ChannelID: 120,
}
expectBreak(t, middleware(pin))
var tpg = &gateway.TypingStartEvent{}
expectBreak(t, middleware(tpg))
})
}
func TestGuildOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = GuildOnly(ctx)
t.Run("allow message with GuildID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
expectNil(t, middleware(msg))
})
t.Run("allow message with ChannelID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
ChannelID: 69420,
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 12,
},
}
expectBreak(t, middleware(msg))
var msg2 = &gateway.MessageCreateEvent{}
expectBreak(t, middleware(msg2))
})
}
func expectNil(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal("Unexpected error:", err)
}
}
func expectBreak(t *testing.T, err error) {
t.Helper()
if errors.Is(err, bot.Break) {
return
}
if err != nil {
t.Fatal("Unexpected error:", err)
}
t.Fatal("Expected error, got nothing.")
}
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkGuildOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = GuildOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkAdminOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = AdminOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 1337,
Author: discord.User{ID: 69420},
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
type mockStore struct {
state.NoopStore
}
func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
return &discord.Guild{
ID: id,
Roles: []discord.Role{{
ID: 69420,
Permissions: discord.PermissionAdministrator,
}},
}, nil
}
func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) {
return &discord.Member{
User: discord.User{ID: userID},
RoleIDs: []discord.RoleID{discord.RoleID(userID)},
}, nil
}
// Channel returns a channel with a guildID for #69420.
func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
if id == 69420 {
return &discord.Channel{
ID: id,
GuildID: 1337,
}, nil
}
return &discord.Channel{
ID: id,
}, nil
}

View file

@ -1,19 +1,58 @@
package shellwords
import (
"errors"
"fmt"
"strings"
)
// WordOffset is the offset from the position cursor to print on the error.
const WordOffset = 7
var escaper = strings.NewReplacer(
"`", "\\`",
"@", "\\@",
"\\", "\\\\",
)
type ErrParse struct {
Line string
Position string
Position int
Words string // joined
}
func (e ErrParse) Error() string {
// Magic number 5.
var a = max(0, e.Position-WordOffset)
var b = min(len(e.Words), e.Position+WordOffset)
var word = e.Words[a:b]
var uidx = e.Position - a
errstr := strings.Builder{}
errstr.WriteString("Unexpected quote or escape")
// Do a bound check.
if uidx+1 > len(word) {
// Invalid.
errstr.WriteString(".")
return errstr.String()
}
// Write the pre-underline part.
fmt.Fprintf(
&errstr, ": %s__%s__",
escaper.Replace(word[:uidx]),
escaper.Replace(string(word[uidx:])),
)
return errstr.String()
}
// Parse parses the given text to a slice of words.
func Parse(line string) ([]string, error) {
args := []string{}
buf := ""
var args []string
var escaped, doubleQuoted, singleQuoted bool
backtick := ""
var buf strings.Builder
buf.Grow(len(line))
got := false
cursor := 0
@ -22,14 +61,14 @@ func Parse(line string) ([]string, error) {
for _, r := range runes {
if escaped {
buf += string(r)
buf.WriteRune(r)
escaped = false
continue
}
if r == '\\' {
if singleQuoted {
buf += string(r)
buf.WriteRune(r)
} else {
escaped = true
}
@ -39,12 +78,11 @@ func Parse(line string) ([]string, error) {
if isSpace(r) {
switch {
case singleQuoted, doubleQuoted:
buf += string(r)
backtick += string(r)
buf.WriteRune(r)
case got:
cursor += len(buf)
args = append(args, buf)
buf = ""
cursor += buf.Len()
args = append(args, buf.String())
buf.Reset()
got = false
}
continue
@ -59,42 +97,30 @@ func Parse(line string) ([]string, error) {
doubleQuoted = !doubleQuoted
continue
}
case '\'':
case '\'', '`':
if !doubleQuoted {
if singleQuoted {
got = true
}
singleQuoted = !singleQuoted
continue
}
}
got = true
buf += string(r)
buf.WriteRune(r)
}
if got {
args = append(args, buf)
args = append(args, buf.String())
}
if escaped || singleQuoted || doubleQuoted {
// the number of characters to highlight
var (
pos = cursor + 5
start = string(runes[max(cursor-100, 0) : pos-1])
end = string(runes[pos+1 : min(cursor+100, len(runes))])
part = ""
)
for i := pos - 1; i >= 0 && i < len(runes) && i < pos+2; i++ {
if runes[i] == '\\' {
part += "\\"
}
part += string(runes[i])
return args, &ErrParse{
Position: cursor + buf.Len(),
Words: strings.Join(args, " "),
}
return nil, errors.New(
"Unexpected quote or escape: " + start + "__" + part + "__" + end)
}
return args, nil
@ -102,7 +128,7 @@ func Parse(line string) ([]string, error) {
func isSpace(r rune) bool {
switch r {
case ' ', '\t', '\r', '\n':
case ' ', '\t', '\r', '\n', ' ':
return true
}
return false

View file

@ -0,0 +1,64 @@
package shellwords
import (
"reflect"
"testing"
)
type wordsTest struct {
line string
args []string
doErr bool
}
func TestParse(t *testing.T) {
var tests = []wordsTest{
{
`this is a "test"`,
[]string{"this", "is", "a", "test"},
false,
},
{
`hanging "quote`,
[]string{"hanging", "quote"},
true,
},
{
`Hello, 世界`,
[]string{"Hello,", "世界"},
false,
},
{
"this is `inline code`",
[]string{"this", "is", "inline code"},
false,
},
{
"how about a ```go\npackage main\n```\ngo code?",
[]string{"how", "about", "a", "go\npackage main\n", "go", "code?"},
false,
},
{
"this should not crash `",
[]string{"this", "should", "not", "crash"},
true,
},
{
"this should not crash '",
[]string{"this", "should", "not", "crash"},
true,
},
}
for _, test := range tests {
w, err := Parse(test.line)
if err != nil && !test.doErr {
t.Errorf("Error at %q: %v", test.line, err)
continue
}
if !reflect.DeepEqual(w, test.args) {
t.Errorf("Inequality:\n%#v !=\n%#v", w, test.args)
}
}
}

View file

@ -1,107 +0,0 @@
package bot
import "strings"
type NameFlag uint64
var FlagSeparator = 'ー'
const None NameFlag = 0
// !!!
//
// These flags are applied to all events, if possible. The defined behavior
// is to search for "ChannelID" fields or "ID" fields in structs with
// "Channel" in its name. It doesn't handle individual events, as such, will
// not be able to guarantee it will always work. Refer to package infer.
// R - Raw, which tells the library to use the method name as-is (flags will
// still be stripped). For example, if a method is called Reset its
// command will also be Reset, without being all lower-cased.
const Raw NameFlag = 1 << 1
// A - AdminOnly, which tells the library to only run the Subcommand/method
// if the user is admin or not. This will automatically add GuildOnly as
// well.
const AdminOnly NameFlag = 1 << 2
// G - GuildOnly, which tells the library to only run the Subcommand/method
// if the user is inside a guild.
const GuildOnly NameFlag = 1 << 3
// M - Middleware, which tells the library that the method is a middleware.
// The method will be executed anytime a method of the same struct is
// matched.
//
// Using this flag inside the subcommand will drop all methods (this is an
// undefined behavior/UB).
const Middleware NameFlag = 1 << 4
// H - Hidden/Handler, which tells the router to not add this into the list
// of commands, hiding it from Help. Handlers that are hidden will not have
// any arguments parsed. It will be treated as an Event.
const Hidden NameFlag = 1 << 5
// P - Plumb, which tells the router to call only this handler with all the
// arguments (except the prefix string). If plumb is used, only this method
// will be called for the given struct, though all other events as well as
// methods with the H (Hidden/Handler) flag.
//
// This is different from using H (Hidden/Handler), as handlers are called
// regardless of command prefixes. Plumb methods are only called once, and
// no other methods will be called for that struct. That said, a Plumb
// method would still go into Commands, but only itself will be there.
//
// Note that if there's a Plumb method in the main commands, then none of
// the subcommands would be called. This is an unintended but expected side
// effect.
//
// Example
//
// A use for this would be subcommands that don't need a second command, or
// if the main struct manually handles command switching. This example
// demonstrates the second use-case:
//
// func (s *Sub) PーMain(
// c *gateway.MessageCreateGateway, c *Content) error {
//
// // Input: !sub this is a command
// // Output: this is a command
//
// log.Println(c.String())
// return nil
// }
//
const Plumb NameFlag = 1 << 6
func ParseFlag(name string) (NameFlag, string) {
parts := strings.SplitN(name, string(FlagSeparator), 2)
if len(parts) != 2 {
return 0, name
}
var f NameFlag
for _, r := range parts[0] {
switch r {
case 'R':
f |= Raw
case 'A':
f |= AdminOnly | GuildOnly
case 'G':
f |= GuildOnly
case 'M':
f |= Middleware
case 'H':
f |= Hidden
case 'P':
f |= Plumb
}
}
return f, parts[1]
}
func (f NameFlag) Is(flag NameFlag) bool {
return f&flag != 0
}

View file

@ -1,26 +0,0 @@
package bot
import "testing"
func TestNameFlag(t *testing.T) {
type entry struct {
Name string
Expect NameFlag
String string
}
var entries = []entry{{
Name: "AーEcho",
Expect: AdminOnly,
}, {
Name: "RAーGC",
Expect: Raw | AdminOnly,
}}
for _, entry := range entries {
var f, _ = ParseFlag(entry.Name)
if !f.Is(entry.Expect) {
t.Fatalf("unexpected expectation for %s: %v", entry.Name, f)
}
}
}

View file

@ -4,34 +4,30 @@ import (
"reflect"
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/gateway"
)
var (
typeMessageCreate = reflect.TypeOf((*gateway.MessageCreateEvent)(nil))
typeString = reflect.TypeOf("")
typeEmbed = reflect.TypeOf((*discord.Embed)(nil))
typeSend = reflect.TypeOf((*api.SendMessageData)(nil))
typeSubcmd = reflect.TypeOf((*Subcommand)(nil))
typeMessageUpdate = reflect.TypeOf((*gateway.MessageUpdateEvent)(nil))
typeIError = reflect.TypeOf((*error)(nil)).Elem()
typeIManP = reflect.TypeOf((*ManualParser)(nil)).Elem()
typeICusP = reflect.TypeOf((*CustomParser)(nil)).Elem()
typeIParser = reflect.TypeOf((*Parser)(nil)).Elem()
typeIUsager = reflect.TypeOf((*Usager)(nil)).Elem()
typeSetupFn = func() reflect.Type {
method, _ := reflect.TypeOf((*CanSetup)(nil)).
Elem().
MethodByName("Setup")
return method.Type
}()
typeSetupFn = methodType((*CanSetup)(nil), "Setup")
)
func methodType(iface interface{}, name string) reflect.Type {
method, _ := reflect.TypeOf(iface).
Elem().
MethodByName(name)
return method.Type
}
// HelpUnderline formats command arguments with an underline, similar to
// manpages.
var HelpUnderline = true
@ -62,38 +58,34 @@ func underline(word string) string {
// func(<AnyEvent>)
//
type Subcommand struct {
// Description is a string that's appended after the subcommand name in
// (*Context).Help().
Description string
// Hidden if true will not be shown by (*Context).Help(). It will
// also cause unknown command errors to be suppressed.
Hidden bool
// Raw struct name, including the flag (only filled for actual subcommands,
// will be empty for Context):
StructName string
// Parsed command name:
Command string
// struct flags
Flag NameFlag
// SanitizeMessage is executed on the message content if the method returns
// a string content or a SendMessageData.
SanitizeMessage func(content string) string
// QuietUnknownCommand, if true, will not make the bot reply with an unknown
// command error into the chat. If this is set in Context, it will apply to
// all other subcommands.
QuietUnknownCommand bool
// Commands can actually return either a string, an embed, or a
// SendMessageData, with error as the second argument.
// All registered command contexts:
Commands []*CommandContext
Events []*CommandContext
// All registered method contexts:
Events []*MethodContext
Commands []*MethodContext
plumbed *MethodContext
// Middleware command contexts:
mwMethods []*CommandContext
// Plumb nameflag, use Commands[0] if true.
plumb bool
// Global middlewares.
globalmws []*MiddlewareContext
// Directly to struct
cmdValue reflect.Value
@ -103,34 +95,10 @@ type Subcommand struct {
ptrValue reflect.Value
ptrType reflect.Type
// command interface as reference
helper func() string
command interface{}
}
// CommandContext is an internal struct containing fields to make this library
// work. As such, they're all unexported. Description, however, is exported for
// editing, and may be used to generate more informative help messages.
type CommandContext struct {
Description string
Flag NameFlag
MethodName string
Command string // empty if Plumb
// Hidden is true if the method has a hidden nameflag.
Hidden bool
// Variadic is true if the function is a variadic one or if the last
// argument accepts multiple strings.
Variadic bool
value reflect.Value // Func
event reflect.Type // gateway.*Event
method reflect.Method
Arguments []Argument
}
// CanSetup is used for subcommands to change variables, such as Description.
// This method will be triggered when InitCommands is called, which is during
// New for Context and during RegisterSubcommand for subcommands.
@ -139,17 +107,12 @@ type CanSetup interface {
Setup(*Subcommand)
}
func (cctx *CommandContext) Usage() []string {
if len(cctx.Arguments) == 0 {
return nil
}
var arguments = make([]string, len(cctx.Arguments))
for i, arg := range cctx.Arguments {
arguments[i] = arg.String
}
return arguments
// CanHelp is an interface that subcommands can implement to return its own help
// message. Those messages will automatically be indented into suitable sections
// by the default Help() implementation. Unlike Usager or CanSetup, the Help()
// method will be called every time it's needed.
type CanHelp interface {
Help() string
}
// NewSubcommand is used to make a new subcommand. You usually wouldn't call
@ -163,11 +126,11 @@ func NewSubcommand(cmd interface{}) (*Subcommand, error) {
}
if err := sub.reflectCommands(); err != nil {
return nil, errors.Wrap(err, "Failed to reflect commands")
return nil, errors.Wrap(err, "failed to reflect commands")
}
if err := sub.parseCommands(); err != nil {
return nil, errors.Wrap(err, "Failed to parse commands")
return nil, errors.Wrap(err, "failed to parse commands")
}
return &sub, nil
@ -177,115 +140,93 @@ func NewSubcommand(cmd interface{}) (*Subcommand, error) {
// shouldn't be called at all, rather you should use RegisterSubcommand.
func (sub *Subcommand) NeedsName() {
sub.StructName = sub.cmdType.Name()
flag, name := ParseFlag(sub.StructName)
if !flag.Is(Raw) {
name = lowerFirstLetter(name)
}
sub.Command = name
sub.Flag = flag
sub.Command = lowerFirstLetter(sub.StructName)
}
// FindCommand finds the command. Nil is returned if nothing is found. It's a
// better idea to not handle nil, as they would become very subtle bugs.
func (sub *Subcommand) FindCommand(methodName string) *CommandContext {
// FindCommand finds the MethodContext. It panics if methodName is not found.
func (sub *Subcommand) FindCommand(methodName string) *MethodContext {
for _, c := range sub.Commands {
if c.MethodName != methodName {
continue
if c.MethodName == methodName {
return c
}
return c
}
return nil
panic("Can't find method " + methodName)
}
// ChangeCommandInfo changes the matched methodName's Command and Description.
// Empty means unchanged. The returned bool is true when the method is found.
func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) bool {
for _, c := range sub.Commands {
if c.MethodName != methodName {
continue
}
if cmd != "" {
c.Command = cmd
}
if desc != "" {
c.Description = desc
}
return true
// Empty means unchanged. This function panics if methodName is not found.
func (sub *Subcommand) ChangeCommandInfo(methodName, cmd, desc string) {
var command = sub.FindCommand(methodName)
if cmd != "" {
command.Command = cmd
}
if desc != "" {
command.Description = desc
}
return false
}
func (sub *Subcommand) Help(indent string, hideAdmin bool) string {
if sub.Flag.Is(AdminOnly) && hideAdmin {
return ""
// Help calls the subcommand's Help() or auto-generates one with HelpGenerate()
// if the subcommand doesn't implement CanHelp. It doesn't show hidden commands
// by default.
func (sub *Subcommand) Help() string {
return sub.HelpShowHidden(false)
}
// HelpShowHidden does the same as Help(), except it will render hidden commands
// if the subcommand doesn't implement CanHelp and showHiddeen is true.
func (sub *Subcommand) HelpShowHidden(showHidden bool) string {
// Check if the subcommand implements CanHelp.
if sub.helper != nil {
return sub.helper()
}
return sub.HelpGenerate(showHidden)
}
// The header part:
var header string
// HelpGenerate auto-generates a help message. Use this only if you want to
// override the Subcommand's help, else use Help(). This function will show
// hidden commands if showHidden is true.
func (sub *Subcommand) HelpGenerate(showHidden bool) string {
// A wider space character.
const s = "\u2000"
if sub.Command != "" {
header += "**" + sub.Command + "**"
}
if sub.Description != "" {
if header != "" {
header += ": "
}
header += sub.Description
}
header += "\n"
// The commands part:
var commands = ""
var buf strings.Builder
for i, cmd := range sub.Commands {
if cmd.Flag.Is(AdminOnly) && hideAdmin {
if cmd.Hidden && !showHidden {
continue
}
switch {
case sub.Command != "" && cmd.Command != "":
commands += indent + sub.Command + " " + cmd.Command
case sub.Command != "":
commands += indent + sub.Command
default:
commands += indent + cmd.Command
}
buf.WriteString(sub.Command + " " + cmd.Command)
// Write the usages first.
for _, usage := range cmd.Usage() {
commands += " " + underline(usage)
}
// Is the last argument trailing? If so, append ellipsis.
if cmd.Variadic {
usage += "..."
}
// Is the last argument trailing? If so, append ellipsis.
if cmd.Variadic {
commands += "..."
// Uses \u2000, which is wider than a space.
buf.WriteString(s + "__" + usage + "__")
}
// Write the description if there's any.
if cmd.Description != "" {
commands += ": " + cmd.Description
buf.WriteString(": " + cmd.Description)
}
// Add a new line if this isn't the last command.
if i != len(sub.Commands)-1 {
commands += "\n"
buf.WriteByte('\n')
}
}
if commands == "" {
return ""
}
return buf.String()
}
return header + commands
// Hide marks a command as hidden, meaning it won't be shown in help and its
// UnknownCommand errors will be suppressed.
func (sub *Subcommand) Hide(methodName string) {
sub.FindCommand(methodName).Hidden = true
}
func (sub *Subcommand) reflectCommands() error {
@ -327,10 +268,9 @@ func (sub *Subcommand) InitCommands(ctx *Context) error {
v.Setup(sub)
}
// Finalize the subcommand:
for _, cmd := range sub.Commands {
// Inherit parent's flags
cmd.Flag |= sub.Flag
// See if struct implements CanHelper:
if v, ok := sub.command.(CanHelp); ok {
sub.helper = v.Help
}
return nil
@ -352,7 +292,7 @@ func (sub *Subcommand) fillStruct(ctx *Context) error {
return nil
}
return errors.New("No fields with *bot.Context found")
return errors.New("no fields with *bot.Context found")
}
func (sub *Subcommand) parseCommands() error {
@ -365,126 +305,114 @@ func (sub *Subcommand) parseCommands() error {
continue
}
methodT := method.Type()
numArgs := methodT.NumIn()
if numArgs == 0 {
// Doesn't meet the requirement for an event, continue.
methodT := sub.ptrType.Method(i)
if methodT.Name == "Setup" && methodT.Type == typeSetupFn {
continue
}
if methodT == typeSetupFn {
// Method is a setup method, continue.
cctx := parseMethod(method, methodT)
if cctx == nil {
continue
}
// Check number of returns:
numOut := methodT.NumOut()
// Returns can either be:
// Nothing - func()
// An error - func() error
// An error and something else - func() (T, error)
if numOut > 2 {
continue
// Append.
if cctx.event == typeMessageCreate {
sub.Commands = append(sub.Commands, cctx)
} else {
sub.Events = append(sub.Events, cctx)
}
// Check the last return's type if the method returns anything.
if numOut > 0 {
if i := methodT.Out(numOut - 1); i == nil || !i.Implements(typeIError) {
// Invalid, skip.
continue
}
}
var command = CommandContext{
method: sub.ptrType.Method(i),
value: method,
event: methodT.In(0), // parse event
Variadic: methodT.IsVariadic(),
}
// Parse the method name
flag, name := ParseFlag(command.method.Name)
// Set the method name, command, and flag:
command.MethodName = name
command.Command = name
command.Flag = flag
// Check if Raw is enabled for command:
if !flag.Is(Raw) {
command.Command = lowerFirstLetter(name)
}
// Middlewares shouldn't even have arguments.
if flag.Is(Middleware) {
sub.mwMethods = append(sub.mwMethods, &command)
continue
}
// TODO: allow more flexibility
if command.event != typeMessageCreate || flag.Is(Hidden) {
sub.Events = append(sub.Events, &command)
continue
}
// See if we know the first return type, if error's return is the
// second:
if numOut > 1 {
switch t := methodT.Out(0); t {
case typeString, typeEmbed, typeSend:
// noop, passes
default:
continue
}
}
// If a plumb method has been found:
if sub.plumb {
continue
}
// If the method only takes an event:
if numArgs == 1 {
sub.Commands = append(sub.Commands, &command)
continue
}
command.Arguments = make([]Argument, 0, numArgs)
// Fill up arguments. This should work with cusP and manP
for i := 1; i < numArgs; i++ {
t := methodT.In(i)
a, err := newArgument(t, command.Variadic)
if err != nil {
return errors.Wrap(err, "Error parsing argument "+t.String())
}
command.Arguments = append(command.Arguments, *a)
// We're done if the type accepts multiple arguments.
if a.custom != nil || a.manual != nil {
command.Variadic = true // treat as variadic
break
}
}
// If the current event is a plumb event:
if flag.Is(Plumb) {
command.Command = "" // plumbers don't have names
sub.Commands = []*CommandContext{&command}
sub.plumb = true
continue
}
// Append
sub.Commands = append(sub.Commands, &command)
}
return nil
}
// AddMiddleware adds a middleware into multiple or all methods, including
// commands and events. Multiple method names can be comma-delimited. For all
// methods, use a star (*). The given middleware argument can either be a
// function with one of the allowed methods or a *MiddlewareContext.
//
// Allowed function signatures
//
// Below are the acceptable function signatures that would be parsed as a proper
// middleware. A return value of type T will be ignored. If the given function
// is invalid, then this method will panic.
//
// func(<AnyEvent>) (T, error)
// func(<AnyEvent>) error
// func(<AnyEvent>)
//
// Note that although technically all of the above function signatures are
// acceptable, one should almost always return only an error.
func (sub *Subcommand) AddMiddleware(methodName string, middleware interface{}) {
var mw *MiddlewareContext
// Allow *MiddlewareContext to be passed into.
if v, ok := middleware.(*MiddlewareContext); ok {
mw = v
} else {
mw = ParseMiddleware(middleware)
}
// Parse method name:
for _, method := range strings.Split(methodName, ",") {
// Trim space.
if method = strings.TrimSpace(method); method == "*" {
// Append middleware to global middleware slice.
sub.globalmws = append(sub.globalmws, mw)
continue
}
// Append middleware to that individual function.
sub.findMethod(method).addMiddleware(mw)
}
}
func (sub *Subcommand) findMethod(name string) *MethodContext {
for _, ev := range sub.Events {
if ev.MethodName == name {
return ev
}
}
return sub.FindCommand(name)
}
func (sub *Subcommand) eventCallers(evT reflect.Type) (callers []caller) {
// Search for global middlewares.
for _, mw := range sub.globalmws {
if mw.isEvent(evT) {
callers = append(callers, mw)
}
}
// Search for specific handlers.
for _, cctx := range sub.Events {
// We only take middlewares and callers if the event matches and is not
// a MessageCreate. The other function already handles that.
if cctx.isEvent(evT) {
// Add the command's middlewares first.
for _, mw := range cctx.middlewares {
// Concrete struct to interface conversion done implicitly.
callers = append(callers, mw)
}
callers = append(callers, cctx)
}
}
return
}
// SetPlumb sets the method as the plumbed command.
func (sub *Subcommand) SetPlumb(methodName string) {
sub.plumbed = sub.FindCommand(methodName)
}
// AddAliases add alias(es) to specific command (defined with commandName).
func (sub *Subcommand) AddAliases(commandName string, aliases ...string) {
// Get command
command := sub.FindCommand(commandName)
// Write new listing of aliases
command.Aliases = append(command.Aliases, aliases...)
}
func lowerFirstLetter(name string) string {
return strings.ToLower(string(name[0])) + name[1:]
}

View file

@ -1,9 +1,22 @@
package bot
import (
"strings"
"testing"
)
func TestUnderline(t *testing.T) {
HelpUnderline = false
if underline("astolfo") != "astolfo" {
t.Fatal("Unexpected underlining with HelpUnderline = false")
}
HelpUnderline = true
if underline("arikawa hime") != "__arikawa hime__" {
t.Fatal("Unexpected normal style with HelpUnderline = true")
}
}
func TestNewSubcommand(t *testing.T) {
_, err := NewSubcommand(&testc{})
if err != nil {
@ -29,8 +42,11 @@ func TestSubcommand(t *testing.T) {
}
// !!! CHANGE ME
if len(sub.Commands) != 8 {
t.Fatal("invalid ctx.commands len", len(sub.Commands))
if len(sub.Commands) < 8 {
t.Fatal("too low sub.Methods len", len(sub.Commands))
}
if len(sub.Events) < 1 {
t.Fatal("No events found.")
}
var (
@ -58,13 +74,6 @@ func TestSubcommand(t *testing.T) {
if len(this.Arguments) != 0 {
t.Fatal("expected 0 arguments, got non-zero")
}
case "noop", "getCounter", "variadic", "trailCustom", "content":
// Found, but whatever
}
if this.event != typeMessageCreate {
t.Fatal("invalid event type:", this.event.String())
}
}
@ -81,10 +90,29 @@ func TestSubcommand(t *testing.T) {
}
})
t.Run("init commands", func(t *testing.T) {
ctx := &Context{}
if err := sub.InitCommands(ctx); err != nil {
t.Fatal("Failed to init commands:", err)
}
})
t.Run("help commands", func(t *testing.T) {
if h := sub.Help("", false); h == "" {
h := sub.Help()
if h == "" {
t.Fatal("Empty subcommand help?")
}
if strings.Contains(h, "hidden") {
t.Fatal("Hidden command shown in help:\n", h)
}
})
t.Run("change command", func(t *testing.T) {
sub.ChangeCommandInfo("Noop", "crossdressing", "best")
if h := sub.Help(); !strings.Contains(h, "crossdressing: best") {
t.Fatal("Changed command is not in help.")
}
})
}

View file

@ -6,33 +6,45 @@ import (
"github.com/diamondburned/arikawa/utils/json"
)
// https://discord.com/developers/docs/resources/audit-log#audit-log-object
type AuditLog struct {
// List of webhooks found in the audit log
// Webhooks is the list of webhooks found in the audit log.
Webhooks []Webhook `json:"webhooks"`
// List of users found in the audit log
// Users is the list of users found in the audit log.
Users []User `json:"users"`
// List of audit log entries
// Entries is the list of audit log entries.
Entries []AuditLogEntry `json:"audit_log_entries"`
// List of partial integration objects, only ID, Name, Type, and Account
// Integrations is a list ist of partial integration objects (only ID,
// Name, Type, and Account).
Integrations []Integration `json:"integrations"`
}
// AuditLogEntry is a single entry in the audit log.
//
// https://discord.com/developers/docs/resources/audit-log#audit-log-entry-object
type AuditLogEntry struct {
ID Snowflake `json:"id"`
UserID Snowflake `json:"user_id"`
TargetID string `json:"target_id,omitempty"`
// ID is the id of the entry.
ID AuditLogEntryID `json:"id"`
// TargetID is the id of the affected entity (webhook, user, role, etc.).
TargetID string `json:"target_id,omitempty"`
// Changes are the changes made to the TargetID.
Changes []AuditLogChange `json:"changes,omitempty"`
// UserID is the id of the user who made the changes.
UserID UserID `json:"user_id"`
// ActionType is the type of action that occurred.
ActionType AuditLogEvent `json:"action_type"`
Changes []AuditLogChange `json:"changes,omitempty"`
Options AuditEntryInfo `json:"options,omitempty"`
Reason string `json:"reason,omitempty"`
// Options contains additional info for certain action types.
Options AuditEntryInfo `json:"options,omitempty"`
// Reason is the reason for the change (0-512 characters).
Reason string `json:"reason,omitempty"`
}
// AuditLogEvent is the type of audit log action that occured.
// AuditLogEvent is the type of audit log action that occurred.
type AuditLogEvent uint8
// https://discord.com/developers/docs/resources/audit-log#audit-log-entry-object-audit-log-events
const (
GuildUpdate AuditLogEvent = 1
ChannelCreate AuditLogEvent = 10
@ -71,22 +83,44 @@ const (
IntegrationDelete AuditLogEvent = 82
)
// https://discord.com/developers/docs/resources/audit-log#audit-log-entry-object-optional-audit-entry-info
type AuditEntryInfo struct {
// MEMBER_PRUNE
// DeleteMemberDays is the number of days after which inactive members were
// kicked.
//
// Events: MEMBER_PRUNE
DeleteMemberDays string `json:"delete_member_days,omitempty"`
// MEMBER_PRUNE
// MembersRemoved is the number of members removed by the prune.
//
// Events: MEMBER_PRUNE
MembersRemoved string `json:"members_removed,omitempty"`
// MEMBER_MOVE & MESSAGE_PIN & MESSAGE_UNPIN & MESSAGE_DELETE
ChannelID Snowflake `json:"channel_id,omitempty"`
// MESSAGE_PIN & MESSAGE_UNPIN
MessageID Snowflake `json:"message_id,omitempty"`
// MESSAGE_DELETE & MESSAGE_BULK_DELETE & MEMBER_DISCONNECT & MEMBER_MOVE
// ChannelID is the id of the channel in which the entities were targeted.
//
// Events: MEMBER_MOVE, MESSAGE_PIN, MESSAGE_UNPIN, MESSAGE_DELETE
ChannelID ChannelID `json:"channel_id,omitempty"`
// MessagesID is the id of the message that was targeted.
//
// Events: MESSAGE_PIN, MESSAGE_UNPIN
MessageID MessageID `json:"message_id,omitempty"`
// Count is the number of entities that were targeted.
//
// Events: MESSAGE_DELETE, MESSAGE_BULK_DELETE, MEMBER_DISCONNECT,
// MEMBER_MOVE
Count string `json:"count,omitempty"`
// CHANNEL_OVERWRITE_CREATE & CHANNEL_OVERWRITE_UPDATE & CHANNEL_OVERWRITE_DELETE
// ID is the id of the overwritten entity.
//
// Events: CHANNEL_OVERWRITE_CREATE, CHANNEL_OVERWRITE_UPDATE,
// CHANNEL_OVERWRITE_DELETE
ID Snowflake `json:"id,omitempty"`
// CHANNEL_OVERWRITE_CREATE & CHANNEL_OVERWRITE_UPDATE & CHANNEL_OVERWRITE_DELETE
// Type is the type of overwritten entity.
//
// Events: CHANNEL_OVERWRITE_CREATE, CHANNEL_OVERWRITE_UPDATE,
// CHANNEL_OVERWRITE_DELETE
Type ChannelOverwritten `json:"type,omitempty"`
// CHANNEL_OVERWRITE_CREATE & CHANNEL_OVERWRITE_UPDATE & CHANNEL_OVERWRITE_DELETE
// RoleName is the name of the role if type is "role".
//
// Events: CHANNEL_OVERWRITE_CREATE, CHANNEL_OVERWRITE_UPDATE,
// CHANNEL_OVERWRITE_DELETE
RoleName string `json:"role_name,omitempty"`
}
@ -117,8 +151,8 @@ const (
// return errors.New("not owner ID")
// }
//
// // We know these are snowflakes because the comment said so for AuditGuildOwnerID.
// var oldOwnerID, newOwnerID discord.Snowflake
// // We know these are UserIDs because the comment said so for AuditGuildOwnerID.
// var oldOwnerID, newOwnerID discord.UserID
// if err := change.UnmarshalValues(&oldOwnerID, &newOwnerID); err != nil {
// return err
// }
@ -126,141 +160,260 @@ const (
// log.Println("Transferred ownership from user", oldOwnerID, "to", newOwnerID)
//
type AuditLogChange struct {
Key string `json:"key"`
NewValue json.Raw `json:"new_value,omitempty"` // nullable
OldValue json.Raw `json:"old_value,omitempty"` // nullable
// Key is the name of audit log change key.
Key AuditLogChangeKey `json:"key"`
// NewValue is the new value of the key.
NewValue json.Raw `json:"new_value,omitempty"`
// OldValue is the old value of the key.
OldValue json.Raw `json:"old_value,omitempty"`
}
// UnmarshalValues unmarshals the values of the AuditLogChange into the passed
// interfaces.
func (a AuditLogChange) UnmarshalValues(old, new interface{}) error {
if err := a.NewValue.UnmarshalTo(new); err != nil {
return errors.Wrap(err, "Failed to unmarshal old value")
return errors.Wrap(err, "failed to unmarshal old value")
}
if err := a.OldValue.UnmarshalTo(old); err != nil {
return errors.Wrap(err, "Failed to unmarshal new value")
return errors.Wrap(err, "failed to unmarshal new value")
}
return nil
}
type AuditLogChangeKey string
// https://discord.com/developers/docs/resources/audit-log#audit-log-change-object-audit-log-change-key
const (
// Type string, name changed
// AuditGuildName gets sent if the guild's name was changed.
//
// Type: string
AuditGuildName AuditLogChangeKey = "name"
// Type Hash, icon changed
// AuditGuildIconHash gets sent if the guild's icon was changed.
//
// Type: Hash
AuditGuildIconHash AuditLogChangeKey = "icon_hash"
// Type Hash, invite splash page artwork changed
// AuditGuildSplashHash gets sent if the guild's invite splash page artwork
// was changed.
//
// Type: Hash
AuditGuildSplashHash AuditLogChangeKey = "splash_hash"
// Type Snowflake, owner changed
// AuditGuildOwnerID gets sent if the guild's owner changed.
//
// Type: UserID
AuditGuildOwnerID AuditLogChangeKey = "owner_id"
// Type string, region changed
// AuditGuildRegion gets sent if the guild's region changed.
//
// Type: string
AuditGuildRegion AuditLogChangeKey = "region"
// Type Snowflake, afk channel changed
// AuditGuildAFKChannelID gets sent if the guild's afk channel changed.
//
// Type: ChannelID
AuditGuildAFKChannelID AuditLogChangeKey = "afk_channel_id"
// Type Seconds, afk timeout duration changed
// AuditGuildAFKTimeout gets sent if the guild's afk timeout duration
// changed.
//
// Type: Seconds
AuditGuildAFKTimeout AuditLogChangeKey = "afk_timeout"
// Type int, two-factor auth requirement changed
// AuditGuildMFA gets sent if the two-factor auth requirement changed.
//
// Type: MFALevel
AuditGuildMFA AuditLogChangeKey = "mfa_level"
// Type Verification, required verification level changed
// AuditGuildVerification gets sent if the guild's required verification
// level changed
//
// Type: Verification
AuditGuildVerification AuditLogChangeKey = "verification_level"
// Type ExplicitFilter, change in whose messages are scanned and deleted for
// explicit content in the server
// AuditGuildExplicitFilter gets sent if there was a change in whose
// messages are scanned and deleted for explicit content in the server.
//
// Type: ExplicitFilter
AuditGuildExplicitFilter AuditLogChangeKey = "explicit_content_filter"
// Type Notification, default message notification level changed
// AuditGuildNotification gets sent if the default message notification
// level changed.
//
// Type: Notification
AuditGuildNotification AuditLogChangeKey = "default_message_notifications"
// Type string, guild invite vanity URL changed
// AuditGuildVanityURLCode gets sent if the guild invite vanity URL
// changed.
//
// Type: string
AuditGuildVanityURLCode AuditLogChangeKey = "vanity_url_code"
// Type []Role{ID, Name}, new role added
// AuditGuildRoleAdd gets sent if a new role was added.
//
// Type: []Role{ID, Name}
AuditGuildRoleAdd AuditLogChangeKey = "$add"
// Type []Role{ID, Name}, role removed
// AuditGuildRoleRemove gets sent if a role was removed.
//
// Type: []Role{ID, Name}
AuditGuildRoleRemove AuditLogChangeKey = "$remove"
// Type int, change in number of days after which inactive and
// role-unassigned members are kicked
// AuditGuildPruneDeleteDays gets sent if there was a change in number of
// days after which inactive and role-unassigned members are kicked.
//
// Type: int
AuditGuildPruneDeleteDays AuditLogChangeKey = "prune_delete_days"
// Type bool, server widget enabled/disable
// AuditGuildWidgetEnabled gets sent if the guild's widget was
// enabled/disabled.
//
// Type: bool
AuditGuildWidgetEnabled AuditLogChangeKey = "widget_enabled"
// Type Snowflake, channel ID of the server widget changed
// AuditGuildWidgetChannelID gets sent if the channel ID of the guild
// widget changed.
//
// Type: ChannelID
AuditGuildWidgetChannelID AuditLogChangeKey = "widget_channel_id"
// Type Snowflake, ID of the system channel changed
// AuditGuildSystemChannelID gets sent if the ID of the guild's system
// channel changed.
//
// Type: ChannelID
AuditGuildSystemChannelID AuditLogChangeKey = "system_channel_id"
)
const (
// Type int, text or voice channel position changed
// AuditChannelPosition gets sent if a text or voice channel position was
// changed.
//
// Type: int
AuditChannelPosition AuditLogChangeKey = "position"
// Type string, text channel topic changed
// AuditChannelTopic gets sent if the text channel topic changed.
//
// Type: string
AuditChannelTopic AuditLogChangeKey = "topic"
// Type uint, voice channel bitrate changed
// AuditChannelBitrate gets sent if the voice channel bitrate changed.
//
// Type: uint
AuditChannelBitrate AuditLogChangeKey = "bitrate"
// Type []Overwrite, permissions on a channel changed
// AuditChannelPermissionOverwrites gets sent if the permissions on a
// channel changed.
//
// Type: []Overwrite
AuditChannelPermissionOverwrites AuditLogChangeKey = "permission_overwrites"
// Type bool, channel NSFW restriction changed
// AuditChannelNSFW gets sent if the channel NSFW restriction changed.
//
// Type: bool
AuditChannelNSFW AuditLogChangeKey = "nsfw"
// Type Snowflake, application ID of the added or removed webhook or bot
// AuditChannelApplicationID contains the application ID of the added or
// removed webhook or bot.
//
// Type: AppID
AuditChannelApplicationID AuditLogChangeKey = "application_id"
// Type Seconds, amount of seconds a user has to wait before sending another
// message changed
// AuditChannelRateLimitPerUser gets sent if the amount of seconds a user
// has to wait before sending another message changed.
//
// Type: Seconds
AuditChannelRateLimitPerUser AuditLogChangeKey = "rate_limit_per_user"
)
const (
// Type Permissions, permissions for a role changed
// AuditRolePermissions gets sent if the permissions for a role changed.
//
// Type: Permissions
AuditRolePermissions AuditLogChangeKey = "permissions"
// Type Color, role color changed
// AuditRoleColor gets sent if the role color changed.
//
// Type: Color
AuditRoleColor AuditLogChangeKey = "color"
// Type bool, role is now displayed/no longer displayed separate from online
// users
// AuditRoleHoist gets sent if the role is now displayed/no longer
// displayed separate from online users.
//
// Type: bool
AuditRoleHoist AuditLogChangeKey = "hoist"
// Type bool, role is now mentionable/unmentionable
// AuditRoleMentionable gets sent if a role is now
// mentionable/unmentionable.
//
// Type: bool
AuditRoleMentionable AuditLogChangeKey = "mentionable"
// Type Permissions, a permission on a text or voice channel was allowed for
// a role
// AuditRoleAllow gets sent if a permission on a text or voice channel was
// allowed for a role.
//
// Type: Permissions
AuditRoleAllow AuditLogChangeKey = "allow"
// Type Permissions, a permission on a text or voice channel was denied for
// a role
// AuditRoleDeny gets sent if a permission on a text or voice channel was
// denied for a role.
//
// Type: Permissions
AuditRoleDeny AuditLogChangeKey = "deny"
)
const (
// Type string, invite code changed
// AuditInviteCode gets sent if an invite code changed.
//
// Type: string
AuditInviteCode AuditLogChangeKey = "code"
// Type Snowflake, channel for invite code changed
// AuditInviteChannelID gets sent if the channel for an invite code
// changed.
//
// Type: ChannelID
AuditInviteChannelID AuditLogChangeKey = "channel_id"
// Type Snowflake, person who created invite code changed
// AuditInviteInviterID specifies the person who created invite code
// changed.
//
// Type: UserID
AuditInviteInviterID AuditLogChangeKey = "inviter_id"
// Type int, change to max number of times invite code can be used
// AuditInviteMaxUses specifies the change to max number of times invite
// code can be used.
//
// Type: int
AuditInviteMaxUses AuditLogChangeKey = "max_uses"
// Type int, number of times invite code used changed
// AuditInviteUses specifies the number of times invite code used changed.
//
// Type: int
AuditInviteUses AuditLogChangeKey = "uses"
// Type Seconds, how long invite code lasts changed
// AuditInviteMaxAge specifies the how long invite code lasts
// changed.
//
// Type: Seconds
AuditInviteMaxAge AuditLogChangeKey = "max_age"
// Type bool, invite code is temporary/never expires
// AuditInviteTemporary specifies if an invite code is temporary/never
// expires.
//
// Type: bool
AuditInviteTemporary AuditLogChangeKey = "temporary"
)
const (
// Type bool, user server deafened/undeafened
// AuditUserDeaf specifies if the user was server deafened/undeafened.
//
// Type: bool
AuditUserDeaf AuditLogChangeKey = "deaf"
// Type bool, user server muted/unmuted
// AuditUserMute specifies if the user was server muted/unmuted.
//
// Type: bool
AuditUserMute AuditLogChangeKey = "mute"
// Type string, user nickname changed
// AuditUserNick specifies the new nickname of the user.
//
// Type: string
AuditUserNick AuditLogChangeKey = "nick"
// Type Hash, user avatar changed
// AuditUserAvatar specifies the hash of the new user avatar.
//
// Type: Hash
AuditUserAvatarHash AuditLogChangeKey = "avatar_hash"
)
const (
// Type Snowflake, the ID of the changed entity - sometimes used in
// conjunction with other keys
// AuditAnyID specifies the ID of the changed entity - sometimes used in
// conjunction with other keys.
//
// Type: Snowflake
AuditAnyID AuditLogChangeKey = "id"
// Type int (channel type) or string, type of entity created
// AuditAnyType is the type of the entity created.
// Type ChannelType or string
AuditAnyType AuditLogChangeKey = "type"
)
const (
// Type bool, integration emoticons enabled/disabled
// AuditIntegrationEnableEmoticons gets sent if the integration emoticons
// were enabled/disabled.
//
// Type: bool
AuditIntegrationEnableEmoticons AuditLogChangeKey = "enable_emoticons"
// Type int, integration expiring subscriber behavior changed
// AuditIntegrationExpireBehavior gets sent if the integration expiring
// subscriber behavior changed.
//
// Type: ExpireBehavior
AuditIntegrationExpireBehavior AuditLogChangeKey = "expire_behavior"
// Type int, integration expire grace period changed
// AuditIntegrationExpireGracePeriod gets sent if the integration expire
// grace period changed.
//
// Type: int
AuditIntegrationExpireGracePeriod AuditLogChangeKey = "expire_grace_period"
)

View file

@ -1,83 +1,148 @@
package discord
import "github.com/diamondburned/arikawa/utils/json"
// https://discord.com/developers/docs/resources/channel#channel-object
type Channel struct {
ID Snowflake `json:"id,string"`
// ID is the id of this channel.
ID ChannelID `json:"id,string"`
// Type is the type of channel.
Type ChannelType `json:"type"`
// GuildID is the id of the guild.
GuildID GuildID `json:"guild_id,string,omitempty"`
// Fields below may not appear
GuildID Snowflake `json:"guild_id,string,omitempty"`
Position int `json:"position,omitempty"`
Name string `json:"name,omitempty"` // 2-100 chars
Topic string `json:"topic,omitempty"` // 0-1024 chars
NSFW bool `json:"nsfw"`
Icon Hash `json:"icon,omitempty"`
// Direct Messaging fields
DMOwnerID Snowflake `json:"owner_id,string,omitempty"`
DMRecipients []User `json:"recipients,omitempty"`
// AppID of the group DM creator if it's bot-created
AppID Snowflake `json:"application_id,string,omitempty"`
// ID of the category the channel is in, if any.
CategoryID Snowflake `json:"parent_id,string,omitempty"`
LastPinTime Timestamp `json:"last_pin_timestamp,omitempty"`
// Explicit permission overrides for members and roles.
// Position is the sorting position of the channel.
Position int `json:"position,omitempty"`
// Permissions are the explicit permission overrides for members and roles.
Permissions []Overwrite `json:"permission_overwrites,omitempty"`
// ID of the last message, may not point to a valid one.
LastMessageID Snowflake `json:"last_message_id,string,omitempty"`
// Slow mode duration. Bots and people with "manage_messages" or
// "manage_channel" permissions are unaffected.
// Name is the name of the channel (2-100 characters).
Name string `json:"name,omitempty"`
// Topic is the channel topic (0-1024 characters).
Topic string `json:"topic,omitempty"`
// NSFW specifies whether the channel is nsfw.
NSFW bool `json:"nsfw"`
// LastMessageID is the id of the last message sent in this channel (may
// not point to an existing or valid message).
LastMessageID MessageID `json:"last_message_id,string,omitempty"`
// VoiceBitrate is the bitrate (in bits) of the voice channel.
VoiceBitrate uint `json:"bitrate,omitempty"`
// VoiceUserLimit is the user limit of the voice channel.
VoiceUserLimit uint `json:"user_limit,omitempty"`
// UserRateLimit is the amount of seconds a user has to wait before sending
// another message (0-21600). Bots, as well as users with the permission
// manage_messages or manage_channel, are unaffected.
UserRateLimit Seconds `json:"rate_limit_per_user,omitempty"`
// Voice, so GuildVoice only
VoiceBitrate uint `json:"bitrate,omitempty"`
VoiceUserLimit uint `json:"user_limit,omitempty"`
// DMRecipients are the recipients of the DM.
DMRecipients []User `json:"recipients,omitempty"`
// Icon is the icon hash.
Icon Hash `json:"icon,omitempty"`
// DMOwnerID is the id of the DM creator.
DMOwnerID UserID `json:"owner_id,string,omitempty"`
// AppID is the application id of the group DM creator if it is
// bot-created.
AppID AppID `json:"application_id,string,omitempty"`
// CategoryID is the id of the parent category for a channel (each parent
// category can contain up to 50 channels).
CategoryID ChannelID `json:"parent_id,string,omitempty"`
// LastPinTime is when the last pinned message was pinned.
LastPinTime Timestamp `json:"last_pin_timestamp,omitempty"`
}
// Mention returns a mention of the channel.
func (ch Channel) Mention() string {
return "<#" + ch.ID.String() + ">"
return ch.ID.Mention()
}
// IconURL returns the icon of the channel. This function will only return
// something if ch.Icon is not empty.
// IconURL returns the URL to the channel icon in the PNG format.
// An empty string is returned if there's no icon.
func (ch Channel) IconURL() string {
return ch.IconURLWithType(PNGImage)
}
// IconURLWithType returns the URL to the channel icon using the passed
// ImageType. An empty string is returned if there's no icon.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (ch Channel) IconURLWithType(t ImageType) string {
if ch.Icon == "" {
return ""
}
return "https://cdn.discordapp.com/channel-icons/" +
ch.ID.String() + "/" + ch.Icon + ".png"
ch.ID.String() + "/" + t.format(ch.Icon)
}
type ChannelType uint8
// https://discord.com/developers/docs/resources/channel#channel-object-channel-types
var (
GuildText ChannelType = 0
// GuildText is a text channel within a server.
GuildText ChannelType = 0
// DirectMessage is a direct message between users.
DirectMessage ChannelType = 1
GuildVoice ChannelType = 2
GroupDM ChannelType = 3
// GuildVoice is a voice channel within a server.
GuildVoice ChannelType = 2
// GroupDM is a direct message between multiple users.
GroupDM ChannelType = 3
// GuildCategory is an organizational category that contains up to 50
// channels.
GuildCategory ChannelType = 4
GuildNews ChannelType = 5
GuildStore ChannelType = 6
// GuildNews is a channel that users can follow and crosspost into their
// own server.
GuildNews ChannelType = 5
// GuildStore is a channel in which game developers can sell their game on
// Discord.
GuildStore ChannelType = 6
)
// https://discord.com/developers/docs/resources/channel#overwrite-object
type Overwrite struct {
ID Snowflake `json:"id,string,omitempty"`
Type OverwriteType `json:"type"`
Allow Permissions `json:"allow"`
Deny Permissions `json:"deny"`
// ID is the role or user id.
ID Snowflake `json:"id"`
// Type is either "role" or "member".
Type OverwriteType `json:"type"`
// Allow is a permission bit set for granted permissions.
Allow Permissions `json:"allow,string"`
// Deny is a permission bit set for denied permissions.
Deny Permissions `json:"deny,string"`
}
// UnmarshalJSON unmarshals the passed json data into the Overwrite.
// This is necessary because Discord has different names for fields when
// sending than receiving.
func (o *Overwrite) UnmarshalJSON(data []byte) (err error) {
var recv struct {
ID Snowflake `json:"id"`
Type OverwriteType `json:"type"`
Allow Permissions `json:"allow_new,string"`
Deny Permissions `json:"deny_new,string"`
}
err = json.Unmarshal(data, &recv)
if err != nil {
return
}
o.ID = recv.ID
o.Type = recv.Type
o.Allow = recv.Allow
o.Deny = recv.Deny
return
}
type OverwriteType string
const (
OverwriteRole OverwriteType = "role"
// OverwriteRole is an overwrite for a role.
OverwriteRole OverwriteType = "role"
// OverwriteMember is an overwrite for a member.
OverwriteMember OverwriteType = "member"
)

View file

@ -3,22 +3,50 @@ package discord
import "strings"
type Emoji struct {
ID Snowflake `json:"id,string"` // 0 for Unicode emojis
Name string `json:"name"`
ID EmojiID `json:"id,string"` // NullSnowflake for unicode emojis
Name string `json:"name"`
// These fields are optional
RoleIDs []Snowflake `json:"roles,omitempty"`
User User `json:"user,omitempty"`
RoleIDs []RoleID `json:"roles,omitempty"`
User User `json:"user,omitempty"`
RequireColons bool `json:"require_colons,omitempty"`
Managed bool `json:"managed,omitempty"`
Animated bool `json:"animated,omitempty"`
}
// EmojiURL returns the URL of the emoji and auto-detects a suitable type.
//
// This will only work for custom emojis.
func (e Emoji) EmojiURL() string {
if e.Animated {
return e.EmojiURLWithType(GIFImage)
}
return e.EmojiURLWithType(PNGImage)
}
// EmojiURLWithType returns the URL to the emoji's image.
//
// This will only work for custom emojis.
//
// Supported ImageTypes: PNG, GIF
func (e Emoji) EmojiURLWithType(t ImageType) string {
if e.ID.IsNull() {
return ""
}
if t == AutoImage {
return e.EmojiURL()
}
return "https://cdn.discordapp.com/emojis/" + t.format(e.ID.String())
}
// APIString returns a string usable for sending over to the API.
func (e Emoji) APIString() string {
if e.ID == 0 {
if !e.ID.IsValid() {
return e.Name // is unicode
}

View file

@ -1,192 +1,447 @@
package discord
// https://discord.com/developers/docs/resources/guild#guild-object
type Guild struct {
ID Snowflake `json:"id,string"`
Name string `json:"name"`
Icon Hash `json:"icon"`
Splash Hash `json:"splash,omitempty"` // server invite bg
// ID is the guild id.
ID GuildID `json:"id,string"`
// Name is the guild name (2-100 characters, excluding trailing and leading
// whitespace).
Name string `json:"name"`
// Icon is the icon hash.&nullableUint64{}
Icon Hash `json:"icon"`
// Splash is the splash hash.
Splash Hash `json:"splash,omitempty"`
// DiscoverySplash is the discovery splash hash.
//
// Only present for guilds with the "DISCOVERABLE" feature.
DiscoverySplash Hash `json:"discovery_splash,omitempty"`
Owner bool `json:"owner,omitempty"` // self is owner
OwnerID Snowflake `json:"owner_id,string"`
// Owner is true if the user is the owner of the guild.
Owner bool `json:"owner,omitempty"`
// OwnerID is the id of owner.
OwnerID UserID `json:"owner_id,string"`
Permissions Permissions `json:"permissions,omitempty"`
// Permissions are the total permissions for the user in the guild
// (excludes overrides).
Permissions Permissions `json:"permissions_new,omitempty,string"`
// VoiceRegion is the voice region id for the guild.
VoiceRegion string `json:"region"`
AFKChannelID Snowflake `json:"afk_channel_id,string,omitempty"`
AFKTimeout Seconds `json:"afk_timeout"`
// AFKChannelID is the id of the afk channel.
AFKChannelID ChannelID `json:"afk_channel_id,string,omitempty"`
// AFKTimeout is the afk timeout in seconds.
AFKTimeout Seconds `json:"afk_timeout"`
Embeddable bool `json:"embed_enabled,omitempty"`
EmbedChannelID Snowflake `json:"embed_channel_id,string,omitempty"`
// Embeddable is true if the server widget is enabled.
//
// Deprecated: replaced with WidgetEnabled
Embeddable bool `json:"embed_enabled,omitempty"`
// EmbedChannelID is the channel id that the widget will generate an invite
// to, or null if set to no invite .
//
// Deprecated: replaced with WidgetChannelID
EmbedChannelID ChannelID `json:"embed_channel_id,string,omitempty"`
Verification Verification `json:"verification_level"`
Notification Notification `json:"default_message_notifications"`
// Verification is the verification level required for the guild.
Verification Verification `json:"verification_level"`
// Notification is the default message notifications level.
Notification Notification `json:"default_message_notifications"`
// ExplicitFilter is the explicit content filter level.
ExplicitFilter ExplicitFilter `json:"explicit_content_filter"`
Roles []Role `json:"roles"`
Emojis []Emoji `json:"emojis"`
// Roles are the roles in the guild.
Roles []Role `json:"roles"`
// Emojis are the custom guild emojis.
Emojis []Emoji `json:"emojis"`
// Features are the enabled guild features.
Features []GuildFeature `json:"guild_features"`
// MFA is the required MFA level for the guild.
MFA MFALevel `json:"mfa"`
AppID Snowflake `json:"application_id,string,omitempty"`
// AppID is the application id of the guild creator if it is bot-created.
//
// This field is nullable.
AppID AppID `json:"application_id,string,omitempty"`
// Widget is true if the server widget is enabled.
Widget bool `json:"widget_enabled,omitempty"`
// WidgetChannelID is the channel id that the widget will generate an
// invite to, or null if set to no invite.
WidgetChannelID ChannelID `json:"widget_channel_id,string,omitempty"`
WidgetChannelID Snowflake `json:"widget_channel_id,string,omitempty"`
SystemChannelID Snowflake `json:"system_channel_id,string,omitempty"`
// SystemChannelID is the the id of the channel where guild notices such as
// welcome messages and boost events are posted.
SystemChannelID ChannelID `json:"system_channel_id,string,omitempty"`
// SystemChannelFlags are the system channel flags.
SystemChannelFlags SystemChannelFlags `json:"system_channel_flags"`
// It's DefaultMaxPresences when MaxPresences is 0.
// RulesChannelID is the id of the channel where guilds with the "PUBLIC"
// feature can display rules and/or guidelines.
RulesChannelID ChannelID `json:"rules_channel_id"`
// MaxPresences is the maximum number of presences for the guild (the
// default value, currently 25000, is in effect when null is returned, so
// effectively when this field is 0).
MaxPresences uint64 `json:"max_presences,omitempty"`
MaxMembers uint64 `json:"max_members,omitempty"`
// MaxMembers the maximum number of members for the guild.
MaxMembers uint64 `json:"max_members,omitempty"`
// VanityURL is the the vanity url code for the guild.
VanityURLCode string `json:"vanity_url_code,omitempty"`
Description string `json:"description,omitempty"`
Banner Hash `json:"banner,omitempty"`
// Description is the description for the guild, if the guild is
// discoverable.
Description string `json:"description,omitempty"`
NitroBoost NitroBoost `json:"premium_tier"`
NitroBoosters uint64 `json:"premium_subscription_count,omitempty"`
// Banner is the banner hash.
Banner Hash `json:"banner,omitempty"`
// Defaults to en-US, only set if guild has DISCOVERABLE
// NitroBoost is the premium tier (Server Boost level).
NitroBoost NitroBoost `json:"premium_tier"`
// NitroBoosters is the number of boosts this guild currently has.
NitroBoosters uint64 `json:"premium_subscription_count,omitempty"`
// PreferredLocale is the the preferred locale of a guild with the "PUBLIC"
// feature; used in server discovery and notices from Discord. Defaults to
// "en-US".
PreferredLocale string `json:"preferred_locale"`
// Only presented if WithCounts is true.
ApproximateMembers uint64 `json:"approximate_member_count,omitempty"`
// PublicUpdatesChannelID is the id of the channel where admins and
// moderators of guilds with the "PUBLIC" feature receive notices from
// Discord.
PublicUpdatesChannelID ChannelID `json:"public_updates_channel_id"`
// MaxVideoChannelUsers is the maximum amount of users in a video channel.
MaxVideoChannelUsers uint64 `json:"max_video_channel_users,omitempty"`
// ApproximateMembers is the approximate number of members in this guild, returned from the GET /guild/<id> endpoint when with_counts is true
ApproximateMembers uint64 `json:"approximate_member_count,omitempty"`
// ApproximatePresences is the approximate number of non-offline members in
// this guild, returned by the GuildWithCount method.
ApproximatePresences uint64 `json:"approximate_presence_count,omitempty"`
}
// IconURL returns the URL to the guild icon. An empty string is removed if
// there's no icon.
// IconURL returns the URL to the guild icon and auto detects a suitable type.
// An empty string is returned if there's no icon.
func (g Guild) IconURL() string {
return g.IconURLWithType(AutoImage)
}
// IconURLWithType returns the URL to the guild icon using the passed
// ImageType. An empty string is returned if there's no icon.
//
// Supported ImageTypes: PNG, JPEG, WebP, GIF
func (g Guild) IconURLWithType(t ImageType) string {
if g.Icon == "" {
return ""
}
base := "https://cdn.discordapp.com/icons/" +
g.ID.String() + "/" + g.Icon
if len(g.Icon) > 2 && g.Icon[:2] == "a_" {
return base + ".gif"
}
return base + ".png"
return "https://cdn.discordapp.com/icons/" + g.ID.String() + "/" + t.format(g.Icon)
}
// BannerURL returns the URL to the banner, which is the image on top of the
// channels list.
// channels list. This will always return a link to a PNG file.
func (g Guild) BannerURL() string {
return g.BannerURLWithType(PNGImage)
}
// BannerURLWithType returns the URL to the banner, which is the image on top
// of the channels list using the passed image type.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (g Guild) BannerURLWithType(t ImageType) string {
if g.Banner == "" {
return ""
}
return "https://cdn.discordapp.com/banners/" +
g.ID.String() + "/" + g.Banner + ".png"
g.ID.String() + "/" + t.format(g.Banner)
}
// SplashURL returns the URL to the guild splash, which is the invite page's
// background.
// background. This will always return a link to a PNG file.
func (g Guild) SplashURL() string {
return g.SplashURLWithType(PNGImage)
}
// SplashURLWithType returns the URL to the guild splash, which is the invite
// page's background, using the passed ImageType.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (g Guild) SplashURLWithType(t ImageType) string {
if g.Splash == "" {
return ""
}
return "https://cdn.discordapp.com/banners/" +
g.ID.String() + "/" + g.Splash + ".png"
return "https://cdn.discordapp.com/splashes/" +
g.ID.String() + "/" + t.format(g.Splash)
}
// DiscoverySplashURL returns the URL to the guild discovery splash.
// This will always return a link to a PNG file.
func (g Guild) DiscoverySplashURL() string {
return g.DiscoverySplashURLWithType(PNGImage)
}
// DiscoverySplashURLWithType returns the URL to the guild discovery splash,
// using the passed ImageType.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (g Guild) DiscoverySplashURLWithType(t ImageType) string {
if g.DiscoverySplash == "" {
return ""
}
return "https://cdn.discordapp.com/splashes/" +
g.ID.String() + "/" + t.format(g.DiscoverySplash)
}
// https://discord.com/developers/docs/resources/guild#guild-preview-object
type GuildPreview struct {
// ID is the guild id.
ID GuildID `json:"id"`
// Name is the guild name (2-100 characters).
Name string `json:"name"`
// Icon is the icon hash.
Icon Hash `json:"icon"`
// Splash is the splash hash.
Splash Hash `json:"splash"`
// DiscoverySplash is the discovery splash hash.
DiscoverySplash Hash `json:"discovery_splash"`
// Emojis are the custom guild emojis.
Emojis []Emoji `json:"emojis"`
// Features are the enabled guild features.
Features []GuildFeature `json:"guild_features"`
// ApproximateMembers is the approximate number of members in this guild.
ApproximateMembers uint64 `json:"approximate_member_count"`
// ApproximatePresences is the approximate number of online members in this
// guild.
ApproximatePresences uint64 `json:"approximate_presence_count"`
// Description is the description for the guild.
Description string `json:"description,omitempty"`
}
// IconURL returns the URL to the guild icon and auto detects a suitable type.
// An empty string is returned if there's no icon.
func (g GuildPreview) IconURL() string {
return g.IconURLWithType(AutoImage)
}
// IconURLWithType returns the URL to the guild icon using the passed
// ImageType. An empty string is returned if there's no icon.
//
// Supported ImageTypes: PNG, JPEG, WebP, GIF
func (g GuildPreview) IconURLWithType(t ImageType) string {
if g.Icon == "" {
return ""
}
return "https://cdn.discordapp.com/icons/" + g.ID.String() + "/" + t.format(g.Icon)
}
// SplashURL returns the URL to the guild splash, which is the invite page's
// background. This will always return a link to a PNG file.
func (g GuildPreview) SplashURL() string {
return g.SplashURLWithType(PNGImage)
}
// SplashURLWithType returns the URL to the guild splash, which is the invite
// page's background, using the passed ImageType.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (g GuildPreview) SplashURLWithType(t ImageType) string {
if g.Splash == "" {
return ""
}
return "https://cdn.discordapp.com/splashes/" +
g.ID.String() + "/" + t.format(g.Splash)
}
// DiscoverySplashURL returns the URL to the guild discovery splash.
// This will always return a link to a PNG file.
func (g GuildPreview) DiscoverySplashURL() string {
return g.DiscoverySplashURLWithType(PNGImage)
}
// DiscoverySplashURLWithType returns the URL to the guild discovery splash,
// using the passed ImageType.
//
// Supported ImageTypes: PNG, JPEG, WebP
func (g GuildPreview) DiscoverySplashURLWithType(t ImageType) string {
if g.DiscoverySplash == "" {
return ""
}
return "https://cdn.discordapp.com/splashes/" +
g.ID.String() + "/" + t.format(g.DiscoverySplash)
}
// https://discord.com/developers/docs/topics/permissions#role-object
type Role struct {
ID Snowflake `json:"id,string"`
Name string `json:"name"`
// ID is the role id.
ID RoleID `json:"id,string"`
// Name is the role name.
Name string `json:"name"`
Color Color `json:"color"`
Hoist bool `json:"hoist"` // if the role is separated
Position int `json:"position"`
// Color is the integer representation of hexadecimal color code.
Color Color `json:"color"`
// Hoist specifies if this role is pinned in the user listing.
Hoist bool `json:"hoist"`
// Position is the position of this role.
Position int `json:"position"`
Permissions Permissions `json:"permissions"`
// Permissions is the permission bit set.
Permissions Permissions `json:"permissions_new,string"`
Managed bool `json:"managed"`
// Manages specifies whether this role is managed by an integration.
Managed bool `json:"managed"`
// Mentionable specifies whether this role is mentionable.
Mentionable bool `json:"mentionable"`
}
// Mention returns the mention of the Role.
func (r Role) Mention() string {
return "<&" + r.ID.String() + ">"
return r.ID.Mention()
}
// https://discord.com/developers/docs/topics/gateway#presence-update
type Presence struct {
User User `json:"user"`
RoleIDs []Snowflake `json:"roles"`
// User is the user presence is being updated for.
User User `json:"user"`
// RoleIDs are the roles this user is in.
RoleIDs []RoleID `json:"roles"`
// These fields are only filled in gateway events, according to the
// documentation.
Nick string `json:"nick"`
GuildID Snowflake `json:"guild_id"`
// Game is null, or the user's current activity.
Game *Activity `json:"game"`
PremiumSince Timestamp `json:"premium_since,omitempty"`
// GuildID is the id of the guild
GuildID GuildID `json:"guild_id"`
Game *Activity `json:"game"`
// Status is either "idle", "dnd", "online", or "offline".
Status Status `json:"status"`
// Activities are the user's current activities.
Activities []Activity `json:"activities"`
Status Status `json:"status"`
// ClientStaus is the user's platform-dependent status.
//
// https://discord.com/developers/docs/topics/gateway#client-status-object
ClientStatus struct {
// Desktop is the user's status set for an active desktop (Windows,
// Linux, Mac) application session.
Desktop Status `json:"desktop,omitempty"`
Mobile Status `json:"mobile,omitempty"`
Web Status `json:"web,omitempty"`
// Mobile is the user's status set for an active mobile (iOS, Android)
// application session.
Mobile Status `json:"mobile,omitempty"`
// Web is the user's status set for an active web (browser, bot
// account) application session.
Web Status `json:"web,omitempty"`
} `json:"client_status"`
// Premium since specifies when the user started boosting the guild.
PremiumSince Timestamp `json:"premium_since,omitempty"`
// Nick is this users guild nickname (if one is set).
Nick string `json:"nick,omitempty"`
}
// https://discord.com/developers/docs/resources/guild#guild-member-object
//
// The field user won't be included in the member object attached to
// MESSAGE_CREATE and MESSAGE_UPDATE gateway events.
type Member struct {
User User `json:"user"`
Nick string `json:"nick,omitempty"`
RoleIDs []Snowflake `json:"roles"`
// User is the user this guild member represents.
User User `json:"user"`
// Nick is this users guild nickname.
Nick string `json:"nick,omitempty"`
// RoleIDs is an array of role object ids.
RoleIDs []RoleID `json:"roles"`
Joined Timestamp `json:"joined_at"`
// Joined specifies when the user joined the guild.
Joined Timestamp `json:"joined_at"`
// BoostedSince specifies when the user started boosting the guild.
BoostedSince Timestamp `json:"premium_since,omitempty"`
// Deaf specifies whether the user is deafened in voice channels.
Deaf bool `json:"deaf"`
// Mute specifies whether the user is muted in voice channels.
Mute bool `json:"mute"`
}
// Mention returns the mention of the role.
func (m Member) Mention() string {
return "<@!" + m.User.ID.String() + ">"
return m.User.Mention()
}
// https://discord.com/developers/docs/resources/guild#ban-object
type Ban struct {
// Reason is the reason for the ban.
Reason string `json:"reason,omitempty"`
User User `json:"user"`
// User is the banned user.
User User `json:"user"`
}
// https://discord.com/developers/docs/resources/guild#integration-object
type Integration struct {
ID Snowflake `json:"id"`
Name string `json:"name"`
Type Service `json:"type"`
// ID is the integration id.
ID IntegrationID `json:"id"`
// Name is the integration name.
Name string `json:"name"`
// Type is the integration type (twitch, youtube, etc).
Type Service `json:"type"`
// Enables specifies if the integration is enabled.
Enabled bool `json:"enabled"`
// Syncing specifies if the integration is syncing.
Syncing bool `json:"syncing"`
// used for subscribers
RoleID Snowflake `json:"role_id"`
// RoleID is the id that this integration uses for "subscribers".
RoleID RoleID `json:"role_id"`
ExpireBehavior ExpireBehavior `json:"expire_behavior"`
ExpireGracePeriod int `json:"expire_grace_period"`
// EnableEmoticons specifies whether emoticons should be synced for this
// integration (twitch only currently).
EnableEmoticons bool `json:"enable_emoticons,omitempty"`
User User `json:"user"`
// ExpireBehavior is the behavior of expiring subscribers
ExpireBehavior ExpireBehavior `json:"expire_behavior"`
// ExpireGracePeriod is the grace period (in days) before expiring
// subscribers.
ExpireGracePeriod int `json:"expire_grace_period"`
// User is the user for this integration.
User User `json:"user"`
// Account is the integration account information.
//
// https://discord.com/developers/docs/resources/guild#integration-account-object
Account struct {
ID string `json:"id"`
// ID is the id of the account.
ID string `json:"id"`
// Name is the name of the account.
Name string `json:"name"`
} `json:"account"`
// SyncedAt specifies when this integration was last synced.
SyncedAt Timestamp `json:"synced_at"`
}
type GuildEmbed struct {
Enabled bool `json:"enabled"`
ChannelID Snowflake `json:"channel_id,omitempty"`
// https://discord.com/developers/docs/resources/guild#guild-widget-object
type GuildWidget struct {
// Enabled specifies whether the widget is enabled.
Enabled bool `json:"enabled"`
// ChannelID is the widget channel id.
ChannelID ChannelID `json:"channel_id,omitempty"`
}
// DefaultMemberColor is the color used for members without colored roles.
var DefaultMemberColor Color = 0x0
// MemberColor computes the effective color of the Member, taking into account
// the role colors.
func MemberColor(guild Guild, member Member) Color {
var c = DefaultMemberColor
var pos int

View file

@ -4,11 +4,15 @@ import (
"github.com/diamondburned/arikawa/utils/json/enum"
)
// Guild.MaxPresences is 5000 when it's 0.
const DefaultMaxPresences = 5000
// Guild.MaxPresences is this value when it's 0.
// This happens because the Discord API sends JSON null, if the MaxPresences
// reach DefaultMaxPresences, which in turn will be serialized into 0.
const DefaultMaxPresences = 25000
// NitroBoost is the premium tier (Server Boost level).
type NitroBoost uint8
// https://discord.com/developers/docs/resources/guild#guild-object-premium-tier
const (
NoNitroLevel NitroBoost = iota
NitroLevel1
@ -16,46 +20,64 @@ const (
NitroLevel3
)
// MFALevel is the required MFA level for a guild.
type MFALevel uint8
// https://discord.com/developers/docs/resources/guild#guild-object-mfa-level
const (
NoMFA MFALevel = iota
ElevatedMFA
)
type SystemChannelFlags uint8
// https://discord.com/developers/docs/resources/guild#guild-object-system-channel-flags
const (
// SuppressJoinNotifications suppresses member join notifications.
SuppressJoinNotifications SystemChannelFlags = 1 << iota
// SuppressPremiumSubscriptions suppresses server boost notifications.
SuppressPremiumSubscriptions
)
type GuildFeature string
// https://discord.com/developers/docs/resources/guild#guild-object-guild-features
const (
// Guild has access to set an invite splash background
// InviteSplash is set, if the guild has access to set an invite splash
// background.
InviteSplash GuildFeature = "INVITE_SPLASH"
// Guild has access to set 384kbps bitrate in voice (previously VIP voice
// servers)
// VIPRegions is set, if the guild has access to set 384kbps bitrate in
// voice (previously VIP voice servers).
VIPRegions GuildFeature = "VIP_REGIONS"
// Guild has access to set a vanity URL
// VanityURL is set, if the guild has access to set a vanity URL.
VanityURL GuildFeature = "VANITY_URL"
// Guild is verified
// Verified is set, if the guild is verified.
Verified GuildFeature = "VERIFIED"
// Guild is partnered
// Partnered is set, if the guild is partnered.
Partnered GuildFeature = "PARTNERED"
// Guild is public
// Public is set, if the guild is public.
Public GuildFeature = "PUBLIC"
// Guild has access to use commerce features (i.e. create store channels)
// Commerce is set, if the guild has access to use commerce features
// (i.e. create store channels).
Commerce GuildFeature = "COMMERCE"
// Guild has access to create news channels
// News is set, if the guild has access to create news channels.
News GuildFeature = "NEWS"
// Guild is able to be discovered in the directory
// Discoverable is set, if the guild is able to be discovered in the
// directory.
Discoverable GuildFeature = "DISCOVERABLE"
// Guild is able to be featured in the directory
// Featurable is set, if the guild is able to be featured in the directory.
Featurable GuildFeature = "FEATURABLE"
// Guild has access to set an animated guild icon
// AnimatedIcon is set, if the guild has access to set an animated guild
// icon.
AnimatedIcon GuildFeature = "ANIMATED_ICON"
// Guild has access to set a guild banner image
// Banner is set, if the guild has access to set a guild banner image.
Banner GuildFeature = "BANNER"
)
// ExplicitFilter is the explicit content filter level of a guild.
type ExplicitFilter enum.Enum
// https://discord.com/developers/docs/resources/guild#guild-object-explicit-content-filter-level
var (
// NullExplicitFilter serialized to JSON null.
// This should only be used on nullable fields.
@ -82,6 +104,7 @@ func (f ExplicitFilter) MarshalJSON() ([]byte, error) {
// Notification is the default message notification level of a guild.
type Notification enum.Enum
// https://discord.com/developers/docs/resources/guild#guild-object-default-message-notification-level
var (
// NullNotification serialized to JSON null.
// This should only be used on nullable fields.
@ -104,6 +127,7 @@ func (n Notification) MarshalJSON() ([]byte, error) { return enum.ToJSON(enum.En
// Verification is the verification level required for a guild.
type Verification enum.Enum
// https://discord.com/developers/docs/resources/guild#guild-object-verification-level
var (
// NullVerification serialized to JSON null.
// This should only be used on nullable fields.
@ -143,6 +167,7 @@ const (
// ExpireBehavior is the integration expire behavior that regulates what happens, if a subscriber expires.
type ExpireBehavior uint8
// https://discord.com/developers/docs/resources/guild#integration-object-integration-expire-behaviors
var (
// RemoveRole removes the role of the subscriber.
RemoveRole ExpireBehavior = 0

View file

@ -1,22 +1,38 @@
package discord
// Invite represents a code that when used, adds a user to a guild or group
// DM channel.
//
// https://discord.com/developers/docs/resources/invite#invite-object
type Invite struct {
Code string `json:"code"`
Channel Channel `json:"channel"` // partial
Guild *Guild `json:"guild,omitempty"` // partial
Inviter *User `json:"inviter,omitempty"`
// Code is the invite code (unique ID).
Code string `json:"code"`
// Guild is the partial guild this invite is for.
Guild *Guild `json:"guild,omitempty"`
// Channel is the partial channel this invite is for.
Channel Channel `json:"channel"`
// Inviter is the user who created the invite
Inviter *User `json:"inviter,omitempty"`
ApproxMembers uint `json:"approximate_members_count,omitempty"`
Target *User `json:"target_user,omitempty"` // partial
// Target is the target user for this invite.
Target *User `json:"target_user,omitempty"`
// Target type is the type of user target for this invite.
TargetType InviteUserType `json:"target_user_type,omitempty"`
// Only available if Target is
ApproxPresences uint `json:"approximate_presence_count,omitempty"`
// ApproximatePresences is the approximate count of online members (only
// present when Target is set).
ApproximatePresences uint `json:"approximate_presence_count,omitempty"`
// ApproximateMembers is the approximate count of total members
ApproximateMembers uint `json:"approximate_member_count,omitempty"`
InviteMetadata // only available when fetching ChannelInvites or GuildInvites
// InviteMetadata contains extra information about the invite.
// So far, this field is only available when fetching Channel- or
// GuildInvites. Additionally the Uses field is filled when getting the
// VanityURL of a guild.
InviteMetadata
}
// https://discord.com/developers/docs/resources/invite#invite-object-target-user-types
type InviteUserType uint8
const (
@ -25,6 +41,8 @@ const (
)
// Extra information about an invite, will extend the invite object.
//
// https://discord.com/developers/docs/resources/invite#invite-metadata-object
type InviteMetadata struct {
// Number of times this invite has been used
Uses int `json:"uses"`

View file

@ -1,12 +1,16 @@
package discord
import "github.com/diamondburned/arikawa/utils/json/enum"
import (
"fmt"
"github.com/diamondburned/arikawa/utils/json/enum"
)
type Message struct {
ID Snowflake `json:"id,string"`
ID MessageID `json:"id,string"`
Type MessageType `json:"type"`
ChannelID Snowflake `json:"channel_id,string"`
GuildID Snowflake `json:"guild_id,string,omitempty"`
ChannelID ChannelID `json:"channel_id,string"`
GuildID GuildID `json:"guild_id,string,omitempty"`
// The author object follows the structure of the user object, but is only
// a valid user in the case where the message is generated by a user or bot
@ -29,8 +33,8 @@ type Message struct {
// text-based guild channels.
Mentions []GuildUser `json:"mentions"`
MentionRoleIDs []Snowflake `json:"mention_roles"`
MentionEveryone bool `json:"mention_everyone"`
MentionRoleIDs []RoleID `json:"mention_roles"`
MentionEveryone bool `json:"mention_everyone"`
// Not all channel mentions in a message will appear in mention_channels.
MentionChannels []ChannelMention `json:"mention_channels,omitempty"`
@ -43,7 +47,7 @@ type Message struct {
// Used for validating a message was sent
Nonce string `json:"nonce,omitempty"`
WebhookID Snowflake `json:"webhook_id,string,omitempty"`
WebhookID WebhookID `json:"webhook_id,string,omitempty"`
Activity *MessageActivity `json:"activity,omitempty"`
Application *MessageApplication `json:"application,omitempty"`
Reference *MessageReference `json:"message_reference,omitempty"`
@ -53,14 +57,15 @@ type Message struct {
// URL generates a Discord client URL to the message. If the message doesn't
// have a GuildID, it will generate a URL with the guild "@me".
func (m Message) URL() string {
var head = "https://discordapp.com/channels/"
var tail = "/" + m.ChannelID.String() + "/" + m.ID.String()
if !m.GuildID.Valid() {
return head + "@me" + tail
var guildID = "@me"
if m.GuildID.IsValid() {
guildID = m.GuildID.String()
}
return head + m.GuildID.String() + tail
return fmt.Sprintf(
"https://discord.com/channels/%s/%s/%s",
guildID, m.ChannelID.String(), m.ID.String(),
)
}
type MessageType uint8
@ -95,8 +100,8 @@ var (
)
type ChannelMention struct {
ChannelID Snowflake `json:"id,string"`
GuildID Snowflake `json:"guild_id,string"`
ChannelID ChannelID `json:"id,string"`
GuildID GuildID `json:"guild_id,string"`
ChannelType ChannelType `json:"type"`
ChannelName string `json:"name"`
}
@ -127,29 +132,29 @@ const (
//
type MessageApplication struct {
ID Snowflake `json:"id,string"`
CoverID string `json:"cover_image,omitempty"`
Description string `json:"description"`
Icon string `json:"icon"`
Name string `json:"name"`
ID AppID `json:"id,string"`
CoverID string `json:"cover_image,omitempty"`
Description string `json:"description"`
Icon string `json:"icon"`
Name string `json:"name"`
}
//
type MessageReference struct {
ChannelID Snowflake `json:"channel_id,string"`
ChannelID ChannelID `json:"channel_id,string"`
// Field might not be provided
MessageID Snowflake `json:"message_id,string,omitempty"`
GuildID Snowflake `json:"guild_id,string,omitempty"`
MessageID MessageID `json:"message_id,string,omitempty"`
GuildID GuildID `json:"guild_id,string,omitempty"`
}
//
type Attachment struct {
ID Snowflake `json:"id,string"`
Filename string `json:"filename"`
Size uint64 `json:"size"`
ID AttachmentID `json:"id,string"`
Filename string `json:"filename"`
Size uint64 `json:"size"`
URL URL `json:"url"`
Proxy URL `json:"proxy_url"`

View file

@ -80,15 +80,15 @@ func (e *Embed) Validate() error {
}
if len(e.Title) > 256 {
return &ErrOverbound{len(e.Title), 256, "Title"}
return &ErrOverbound{len(e.Title), 256, "title"}
}
if len(e.Description) > 2048 {
return &ErrOverbound{len(e.Description), 2048, "Description"}
return &ErrOverbound{len(e.Description), 2048, "description"}
}
if len(e.Fields) > 25 {
return &ErrOverbound{len(e.Fields), 25, "Fields"}
return &ErrOverbound{len(e.Fields), 25, "fields"}
}
var sum = 0 +
@ -97,7 +97,7 @@ func (e *Embed) Validate() error {
if e.Footer != nil {
if len(e.Footer.Text) > 2048 {
return &ErrOverbound{len(e.Footer.Text), 2048, "Footer text"}
return &ErrOverbound{len(e.Footer.Text), 2048, "footer text"}
}
sum += len(e.Footer.Text)
@ -105,7 +105,7 @@ func (e *Embed) Validate() error {
if e.Author != nil {
if len(e.Author.Name) > 256 {
return &ErrOverbound{len(e.Author.Name), 256, "Author name"}
return &ErrOverbound{len(e.Author.Name), 256, "author name"}
}
sum += len(e.Author.Name)
@ -126,7 +126,7 @@ func (e *Embed) Validate() error {
}
if sum > 6000 {
return &ErrOverbound{sum, 6000, "Sum of all characters"}
return &ErrOverbound{sum, 6000, "sum of all characters"}
}
return nil

View file

@ -126,7 +126,7 @@ func CalcOverwrites(guild Guild, channel Channel, member Member) Permissions {
var perm Permissions
for _, role := range guild.Roles {
if role.ID == guild.ID {
if role.ID == RoleID(guild.ID) {
perm |= role.Permissions
break
}
@ -146,7 +146,7 @@ func CalcOverwrites(guild Guild, channel Channel, member Member) Permissions {
}
for _, overwrite := range channel.Permissions {
if overwrite.ID == guild.ID {
if GuildID(overwrite.ID) == guild.ID {
perm &= ^overwrite.Deny
perm |= overwrite.Allow
break
@ -157,7 +157,7 @@ func CalcOverwrites(guild Guild, channel Channel, member Member) Permissions {
for _, overwrite := range channel.Permissions {
for _, id := range member.RoleIDs {
if id == overwrite.ID && overwrite.Type == "role" {
if id == RoleID(overwrite.ID) && overwrite.Type == "role" {
deny |= overwrite.Deny
allow |= overwrite.Allow
break
@ -169,7 +169,7 @@ func CalcOverwrites(guild Guild, channel Channel, member Member) Permissions {
perm |= allow
for _, overwrite := range channel.Permissions {
if overwrite.ID == member.User.ID {
if UserID(overwrite.ID) == member.User.ID {
perm &= ^overwrite.Deny
perm |= overwrite.Allow
break

View file

@ -6,70 +6,79 @@ import (
"time"
)
// DiscordEpoch is the Discord epoch constant in time.Duration (nanoseconds)
// Epoch is the Discord epoch constant in time.Duration (nanoseconds)
// since Unix epoch.
const DiscordEpoch = 1420070400000 * time.Millisecond
const Epoch = 1420070400000 * time.Millisecond
// DurationSinceDiscordEpoch returns the duration from the Discord epoch to
// current.
func DurationSinceDiscordEpoch(t time.Time) time.Duration {
return time.Duration(t.UnixNano()) - DiscordEpoch
// DurationSinceEpoch returns the duration from the Discord epoch to current.
func DurationSinceEpoch(t time.Time) time.Duration {
return time.Duration(t.UnixNano()) - Epoch
}
type Snowflake int64
type Snowflake uint64
// NullSnowflake gets encoded into a null. This is used for
// optional and nullable snowflake fields.
const NullSnowflake Snowflake = -1
const NullSnowflake = ^Snowflake(0)
func NewSnowflake(t time.Time) Snowflake {
return Snowflake((DurationSinceDiscordEpoch(t) / time.Millisecond) << 22)
return Snowflake((DurationSinceEpoch(t) / time.Millisecond) << 22)
}
func ParseSnowflake(sf string) (Snowflake, error) {
i, err := strconv.ParseInt(sf, 10, 64)
if sf == "null" {
return NullSnowflake, nil
}
u, err := strconv.ParseUint(sf, 10, 64)
if err != nil {
return 0, err
}
return Snowflake(i), nil
return Snowflake(u), nil
}
func (s *Snowflake) UnmarshalJSON(v []byte) error {
id := strings.Trim(string(v), `"`)
if id == "null" {
// Use -1 for null.
*s = -1
return nil
}
i, err := strconv.ParseInt(id, 10, 64)
p, err := ParseSnowflake(strings.Trim(string(v), `"`))
if err != nil {
return err
}
*s = Snowflake(i)
*s = p
return nil
}
func (s Snowflake) MarshalJSON() ([]byte, error) {
if s < 1 {
// This includes 0 and null, because MarshalJSON does not dictate when a
// value gets omitted.
if !s.IsValid() {
return []byte("null"), nil
} else {
return []byte(`"` + strconv.FormatInt(int64(s), 10) + `"`), nil
}
}
// String returns the ID, or nothing if the snowflake isn't valid.
func (s Snowflake) String() string {
// Check if negative.
if !s.IsValid() {
return ""
}
return strconv.FormatUint(uint64(s), 10)
}
func (s Snowflake) Valid() bool {
return uint64(s) > 0
// IsValid returns whether or not the snowflake is valid.
func (s Snowflake) IsValid() bool {
return !(int64(s) == 0 || s == NullSnowflake)
}
// IsNull returns whether or not the snowflake is null.
func (s Snowflake) IsNull() bool {
return s == NullSnowflake
}
func (s Snowflake) Time() time.Time {
unixnano := ((time.Duration(s) >> 22) * time.Millisecond) + DiscordEpoch
unixnano := time.Duration(s>>22)*time.Millisecond + Epoch
return time.Unix(0, int64(unixnano))
}
@ -84,3 +93,160 @@ func (s Snowflake) PID() uint8 {
func (s Snowflake) Increment() uint16 {
return uint16(s & 0xFFF)
}
type AppID Snowflake
const NullAppID = AppID(NullSnowflake)
func (s AppID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *AppID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s AppID) String() string { return Snowflake(s).String() }
func (s AppID) IsValid() bool { return Snowflake(s).IsValid() }
func (s AppID) IsNull() bool { return Snowflake(s).IsNull() }
func (s AppID) Time() time.Time { return Snowflake(s).Time() }
func (s AppID) Worker() uint8 { return Snowflake(s).Worker() }
func (s AppID) PID() uint8 { return Snowflake(s).PID() }
func (s AppID) Increment() uint16 { return Snowflake(s).Increment() }
type AttachmentID Snowflake
const NullAttachmentID = AttachmentID(NullSnowflake)
func (s AttachmentID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *AttachmentID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s AttachmentID) String() string { return Snowflake(s).String() }
func (s AttachmentID) IsValid() bool { return Snowflake(s).IsValid() }
func (s AttachmentID) IsNull() bool { return Snowflake(s).IsNull() }
func (s AttachmentID) Time() time.Time { return Snowflake(s).Time() }
func (s AttachmentID) Worker() uint8 { return Snowflake(s).Worker() }
func (s AttachmentID) PID() uint8 { return Snowflake(s).PID() }
func (s AttachmentID) Increment() uint16 { return Snowflake(s).Increment() }
type AuditLogEntryID Snowflake
const NullAuditLogEntryID = AuditLogEntryID(NullSnowflake)
func (s AuditLogEntryID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *AuditLogEntryID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s AuditLogEntryID) String() string { return Snowflake(s).String() }
func (s AuditLogEntryID) IsValid() bool { return Snowflake(s).IsValid() }
func (s AuditLogEntryID) IsNull() bool { return Snowflake(s).IsNull() }
func (s AuditLogEntryID) Time() time.Time { return Snowflake(s).Time() }
func (s AuditLogEntryID) Worker() uint8 { return Snowflake(s).Worker() }
func (s AuditLogEntryID) PID() uint8 { return Snowflake(s).PID() }
func (s AuditLogEntryID) Increment() uint16 { return Snowflake(s).Increment() }
type ChannelID Snowflake
const NullChannelID = ChannelID(NullSnowflake)
func (s ChannelID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *ChannelID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s ChannelID) String() string { return Snowflake(s).String() }
func (s ChannelID) IsValid() bool { return Snowflake(s).IsValid() }
func (s ChannelID) IsNull() bool { return Snowflake(s).IsNull() }
func (s ChannelID) Time() time.Time { return Snowflake(s).Time() }
func (s ChannelID) Worker() uint8 { return Snowflake(s).Worker() }
func (s ChannelID) PID() uint8 { return Snowflake(s).PID() }
func (s ChannelID) Increment() uint16 { return Snowflake(s).Increment() }
func (s ChannelID) Mention() string { return "<#" + s.String() + ">" }
type EmojiID Snowflake
const NullEmojiID = EmojiID(NullSnowflake)
func (s EmojiID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *EmojiID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s EmojiID) String() string { return Snowflake(s).String() }
func (s EmojiID) IsValid() bool { return Snowflake(s).IsValid() }
func (s EmojiID) IsNull() bool { return Snowflake(s).IsNull() }
func (s EmojiID) Time() time.Time { return Snowflake(s).Time() }
func (s EmojiID) Worker() uint8 { return Snowflake(s).Worker() }
func (s EmojiID) PID() uint8 { return Snowflake(s).PID() }
func (s EmojiID) Increment() uint16 { return Snowflake(s).Increment() }
type IntegrationID Snowflake
const NullIntegrationID = IntegrationID(NullSnowflake)
func (s IntegrationID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *IntegrationID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s IntegrationID) String() string { return Snowflake(s).String() }
func (s IntegrationID) IsValid() bool { return Snowflake(s).IsValid() }
func (s IntegrationID) IsNull() bool { return Snowflake(s).IsNull() }
func (s IntegrationID) Time() time.Time { return Snowflake(s).Time() }
func (s IntegrationID) Worker() uint8 { return Snowflake(s).Worker() }
func (s IntegrationID) PID() uint8 { return Snowflake(s).PID() }
func (s IntegrationID) Increment() uint16 { return Snowflake(s).Increment() }
type GuildID Snowflake
const NullGuildID = GuildID(NullSnowflake)
func (s GuildID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *GuildID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s GuildID) String() string { return Snowflake(s).String() }
func (s GuildID) IsValid() bool { return Snowflake(s).IsValid() }
func (s GuildID) IsNull() bool { return Snowflake(s).IsNull() }
func (s GuildID) Time() time.Time { return Snowflake(s).Time() }
func (s GuildID) Worker() uint8 { return Snowflake(s).Worker() }
func (s GuildID) PID() uint8 { return Snowflake(s).PID() }
func (s GuildID) Increment() uint16 { return Snowflake(s).Increment() }
type MessageID Snowflake
const NullMessageID = MessageID(NullSnowflake)
func (s MessageID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *MessageID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s MessageID) String() string { return Snowflake(s).String() }
func (s MessageID) IsValid() bool { return Snowflake(s).IsValid() }
func (s MessageID) IsNull() bool { return Snowflake(s).IsNull() }
func (s MessageID) Time() time.Time { return Snowflake(s).Time() }
func (s MessageID) Worker() uint8 { return Snowflake(s).Worker() }
func (s MessageID) PID() uint8 { return Snowflake(s).PID() }
func (s MessageID) Increment() uint16 { return Snowflake(s).Increment() }
type RoleID Snowflake
const NullRoleID = RoleID(NullSnowflake)
func (s RoleID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *RoleID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s RoleID) String() string { return Snowflake(s).String() }
func (s RoleID) IsValid() bool { return Snowflake(s).IsValid() }
func (s RoleID) IsNull() bool { return Snowflake(s).IsNull() }
func (s RoleID) Time() time.Time { return Snowflake(s).Time() }
func (s RoleID) Worker() uint8 { return Snowflake(s).Worker() }
func (s RoleID) PID() uint8 { return Snowflake(s).PID() }
func (s RoleID) Increment() uint16 { return Snowflake(s).Increment() }
func (s RoleID) Mention() string { return "<@&" + s.String() + ">" }
type UserID Snowflake
const NullUserID = UserID(NullSnowflake)
func (s UserID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *UserID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s UserID) String() string { return Snowflake(s).String() }
func (s UserID) IsValid() bool { return Snowflake(s).IsValid() }
func (s UserID) IsNull() bool { return Snowflake(s).IsNull() }
func (s UserID) Time() time.Time { return Snowflake(s).Time() }
func (s UserID) Worker() uint8 { return Snowflake(s).Worker() }
func (s UserID) PID() uint8 { return Snowflake(s).PID() }
func (s UserID) Increment() uint16 { return Snowflake(s).Increment() }
func (s UserID) Mention() string { return "<@" + s.String() + ">" }
type WebhookID Snowflake
const NullWebhookID = WebhookID(NullSnowflake)
func (s WebhookID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() }
func (s *WebhookID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) }
func (s WebhookID) String() string { return Snowflake(s).String() }
func (s WebhookID) IsValid() bool { return Snowflake(s).IsValid() }
func (s WebhookID) IsNull() bool { return Snowflake(s).IsNull() }
func (s WebhookID) Time() time.Time { return Snowflake(s).Time() }
func (s WebhookID) Worker() uint8 { return Snowflake(s).Worker() }
func (s WebhookID) PID() uint8 { return Snowflake(s).PID() }
func (s WebhookID) Increment() uint16 { return Snowflake(s).Increment() }

View file

@ -36,6 +36,28 @@ func TestSnowflake(t *testing.T) {
}
})
t.Run("IsValid", func(t *testing.T) {
t.Run("0", func(t *testing.T) {
if Snowflake(0).IsValid() {
t.Fatal("0 isn't a valid Snowflake")
}
})
t.Run("null", func(t *testing.T) {
if NullSnowflake.IsValid() {
t.Fatal("NullSnowflake isn't a valid Snowflake")
}
})
t.Run("valid", func(t *testing.T) {
var testFlake Snowflake = 123
if !testFlake.IsValid() {
t.Fatal(testFlake, "is a valid Snowflake")
}
})
})
t.Run("new", func(t *testing.T) {
if s := NewSnowflake(expect); !s.Time().Equal(expect) {
t.Fatal("Unexpected new snowflake from expected time:", s)

View file

@ -7,7 +7,7 @@ import (
"time"
)
// Timestamp has a valid zero-value, which can be checked using the Valid()
// Timestamp has a valid zero-value, which can be checked using the IsValid()
// method. This is useful for optional timestamps such as EditedTimestamp.
type Timestamp time.Time
@ -45,14 +45,14 @@ func (t *Timestamp) UnmarshalJSON(v []byte) error {
// MarshalJSON returns null if Timestamp is not valid (zero). It returns the
// time formatted in RFC3339 otherwise.
func (t Timestamp) MarshalJSON() ([]byte, error) {
if !t.Valid() {
if !t.IsValid() {
return []byte("null"), nil
}
return []byte(`"` + t.Format(TimestampFormat) + `"`), nil
}
func (t Timestamp) Valid() bool {
func (t Timestamp) IsValid() bool {
return !t.Time().IsZero()
}

View file

@ -1,4 +1,34 @@
package discord
import "strings"
type ImageType string
const (
// AutoImage chooses automatically between a PNG and GIF.
AutoImage ImageType = "auto"
// JPEGImage is the JPEG image type.
JPEGImage ImageType = ".jpeg"
// PNGImage is the PNG image type.
PNGImage ImageType = ".png"
// WebPImage is the WebP image type.
WebPImage ImageType = ".webp"
// GIFImage is the GIF image type.
GIFImage ImageType = ".gif"
)
func (t ImageType) format(name string) string {
if t == AutoImage {
if strings.HasPrefix(name, "a_") {
return name + ".gif"
}
return name + ".png"
}
return name + string(t)
}
type URL = string
type Hash = string

View file

@ -1,16 +1,14 @@
package discord
import "strings"
// DefaultAvatarURL is the link to the default green avatar on Discord. It's
// returned from AvatarURL() if the user doesn't have an avatar.
var DefaultAvatarURL = "https://discordapp.com/assets/dd4dbc0016779df1378e7812eabaa04d.png"
import (
"strconv"
)
type User struct {
ID Snowflake `json:"id,string"`
Username string `json:"username"`
Discriminator string `json:"discriminator"`
Avatar Hash `json:"avatar"`
ID UserID `json:"id,string"`
Username string `json:"username"`
Discriminator string `json:"discriminator"`
Avatar Hash `json:"avatar"`
// These fields may be omitted
@ -29,22 +27,36 @@ type User struct {
}
func (u User) Mention() string {
return "<@" + u.ID.String() + ">"
return u.ID.Mention()
}
// AvatarURL returns the URL of the Avatar Image. It automatically detects a
// suitable type.
func (u User) AvatarURL() string {
return u.AvatarURLWithType(AutoImage)
}
// AvatarURLWithType returns the URL of the Avatar Image using the passed type.
// If the user has no Avatar, his default avatar will be returned. This
// requires ImageType Auto or PNG
//
// Supported Image Types: PNG, JPEG, WebP, GIF (read above for caveat)
func (u User) AvatarURLWithType(t ImageType) string {
if u.Avatar == "" {
return DefaultAvatarURL
if t != PNGImage && t != AutoImage {
return ""
}
disc, err := strconv.Atoi(u.Discriminator)
if err != nil { // this should never happen
return ""
}
picNo := strconv.Itoa(disc % 5)
return "https://cdn.discordapp.com/embed/avatars/" + picNo + ".png"
}
base := "https://cdn.discordapp.com"
base += "/avatars/" + u.ID.String() + "/" + u.Avatar
if strings.HasPrefix(u.Avatar, "a_") {
return base + ".gif"
} else {
return base + ".png"
}
return "https://cdn.discordapp.com/avatars/" + u.ID.String() + "/" + t.format(u.Avatar)
}
type UserFlags uint32
@ -52,8 +64,8 @@ type UserFlags uint32
const NoFlag UserFlags = 0
const (
DiscordEmployee UserFlags = 1 << iota
DiscordPartner
Employee UserFlags = 1 << iota
Partner
HypeSquadEvents
BugHunterLvl1
_
@ -81,9 +93,9 @@ const (
)
type Connection struct {
ID Snowflake `json:"id"`
Name string `json:"name"`
Type Service `json:"type"`
ID string `json:"id"`
Name string `json:"name"`
Type Service `json:"type"`
Revoked bool `json:"revoked"`
Verified bool `json:"verified"`
@ -124,10 +136,10 @@ type Activity struct {
CreatedAt UnixTimestamp `json:"created_at,omitempty"`
Timestamps *ActivityTimestamp `json:"timestamps,omitempty"`
ApplicationID Snowflake `json:"application_id,omitempty"`
Details string `json:"details,omitempty"`
State string `json:"state,omitempty"` // party status
Emoji *Emoji `json:"emoji,omitempty"`
ApplicationID AppID `json:"application_id,omitempty"`
Details string `json:"details,omitempty"`
State string `json:"state,omitempty"` // party status
Emoji *Emoji `json:"emoji,omitempty"`
Party *ActivityParty `json:"party,omitempty"`
Assets *ActivityAssets `json:"assets,omitempty"`
@ -150,7 +162,8 @@ const (
StreamingActivity
// Listening to $name
ListeningActivity
_
// Watching $name
WatchingActivity
// $emoji $state
CustomActivity
)
@ -188,3 +201,22 @@ type ActivitySecrets struct {
Spectate string `json:"spectate,omitempty"`
Match string `json:"match,omitempty"`
}
// A Relationship between the logged in user and the user in the struct. This
// struct is undocumented.
type Relationship struct {
UserID UserID `json:"id"`
User User `json:"user"`
Type RelationshipType `json:"type"`
}
// RelationshipType is an enum for a relationship.
type RelationshipType uint8
const (
_ RelationshipType = iota
FriendRelationship
BlockedRelationship
IncomingFriendRequest
SentFriendRequest
)

View file

@ -2,10 +2,10 @@ package discord
type VoiceState struct {
// GuildID isn't available from the Guild struct.
GuildID Snowflake `json:"guild_id,string"`
GuildID GuildID `json:"guild_id,string"`
ChannelID Snowflake `json:"channel_id,string"`
UserID Snowflake `json:"user_id,string"`
ChannelID ChannelID `json:"channel_id,string"`
UserID UserID `json:"user_id,string"`
Member *Member `json:"member,omitempty"`
SessionID string `json:"session_id"`

View file

@ -1,12 +1,12 @@
package discord
type Webhook struct {
ID Snowflake `json:"id"`
ID WebhookID `json:"id"`
Type WebhookType `json:"type"`
User User `json:"user"` // creator
GuildID Snowflake `json:"guild_id,omitempty"`
ChannelID Snowflake `json:"channel_id"`
GuildID GuildID `json:"guild_id,omitempty"`
ChannelID ChannelID `json:"channel_id"`
Name string `json:"name"`
Avatar Hash `json:"avatar"`

View file

@ -15,11 +15,18 @@ func (g *Gateway) Identify() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.IdentifyCtx(ctx)
}
func (g *Gateway) IdentifyCtx(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, g.WSTimeout)
defer cancel()
if err := g.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for identify()")
return errors.Wrap(err, "can't wait for identify()")
}
return g.Send(IdentifyOP, g.Identifier)
return g.SendCtx(ctx, IdentifyOP, g.Identifier)
}
type ResumeData struct {
@ -31,6 +38,15 @@ type ResumeData struct {
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.ResumeCtx(ctx)
}
// ResumeCtx sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) ResumeCtx(ctx context.Context) error {
var (
ses = g.SessionID
seq = g.Sequence.Get()
@ -40,7 +56,7 @@ func (g *Gateway) Resume() error {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
return g.SendCtx(ctx, ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
@ -51,31 +67,57 @@ func (g *Gateway) Resume() error {
type HeartbeatData int
func (g *Gateway) Heartbeat() error {
return g.Send(HeartbeatOP, g.Sequence.Get())
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.HeartbeatCtx(ctx)
}
func (g *Gateway) HeartbeatCtx(ctx context.Context) error {
return g.SendCtx(ctx, HeartbeatOP, g.Sequence.Get())
}
type RequestGuildMembersData struct {
GuildID []discord.Snowflake `json:"guild_id"`
UserIDs []discord.Snowflake `json:"user_ids,omitempty"`
GuildID []discord.GuildID `json:"guild_id"`
UserIDs []discord.UserID `json:"user_ids,omitempty"`
Query string `json:"query,omitempty"`
Query string `json:"query"`
Limit uint `json:"limit"`
Presences bool `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
func (g *Gateway) RequestGuildMembers(data RequestGuildMembersData) error {
return g.Send(RequestGuildMembersOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.RequestGuildMembersCtx(ctx, data)
}
func (g *Gateway) RequestGuildMembersCtx(
ctx context.Context, data RequestGuildMembersData) error {
return g.SendCtx(ctx, RequestGuildMembersOP, data)
}
type UpdateVoiceStateData struct {
GuildID discord.Snowflake `json:"guild_id"`
ChannelID discord.Snowflake `json:"channel_id"` // nullable
GuildID discord.GuildID `json:"guild_id"`
ChannelID discord.ChannelID `json:"channel_id"` // nullable
SelfMute bool `json:"self_mute"`
SelfDeaf bool `json:"self_deaf"`
}
func (g *Gateway) UpdateVoiceState(data UpdateVoiceStateData) error {
return g.Send(VoiceStateUpdateOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateVoiceStateCtx(ctx, data)
}
func (g *Gateway) UpdateVoiceStateCtx(
ctx context.Context, data UpdateVoiceStateData) error {
return g.SendCtx(ctx, VoiceStateUpdateOP, data)
}
type UpdateStatusData struct {
@ -90,19 +132,33 @@ type UpdateStatusData struct {
}
func (g *Gateway) UpdateStatus(data UpdateStatusData) error {
return g.Send(StatusUpdateOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.UpdateStatusCtx(ctx, data)
}
func (g *Gateway) UpdateStatusCtx(ctx context.Context, data UpdateStatusData) error {
return g.SendCtx(ctx, StatusUpdateOP, data)
}
// Undocumented
type GuildSubscribeData struct {
Typing bool `json:"typing"`
Activities bool `json:"activities"`
GuildID discord.Snowflake `json:"guild_id"`
Typing bool `json:"typing"`
Activities bool `json:"activities"`
GuildID discord.GuildID `json:"guild_id"`
// Channels is not documented. It's used to fetch the right members sidebar.
Channels map[discord.Snowflake][][2]int `json:"channels"`
Channels map[discord.ChannelID][][2]int `json:"channels,omitempty"`
}
func (g *Gateway) GuildSubscribe(data GuildSubscribeData) error {
return g.Send(GuildSubscriptionsOP, data)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.GuildSubscribeCtx(ctx, data)
}
func (g *Gateway) GuildSubscribeCtx(ctx context.Context, data GuildSubscribeData) error {
return g.SendCtx(ctx, GuildSubscriptionsOP, data)
}

View file

@ -20,21 +20,27 @@ type (
// https://discordapp.com/developers/docs/topics/gateway#channels
type (
ChannelCreateEvent discord.Channel
ChannelUpdateEvent discord.Channel
ChannelDeleteEvent discord.Channel
ChannelCreateEvent struct {
discord.Channel
}
ChannelUpdateEvent struct {
discord.Channel
}
ChannelDeleteEvent struct {
discord.Channel
}
ChannelPinsUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id,omitempty"`
ChannelID discord.Snowflake `json:"channel_id,omitempty"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
ChannelID discord.ChannelID `json:"channel_id,omitempty"`
LastPin discord.Timestamp `json:"timestamp,omitempty"`
}
ChannelUnreadUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
GuildID discord.GuildID `json:"guild_id"`
ChannelUnreadUpdates []struct {
ID discord.Snowflake `json:"id"`
LastMessageID discord.Snowflake `json:"last_message_id"`
ID discord.ChannelID `json:"id"`
LastMessageID discord.MessageID `json:"last_message_id"`
}
}
)
@ -51,69 +57,75 @@ type (
VoiceStates []discord.VoiceState `json:"voice_states,omitempty"`
Members []discord.Member `json:"members,omitempty"`
Channels []discord.Channel `json:"channel,omitempty"`
Channels []discord.Channel `json:"channels,omitempty"`
Presences []discord.Presence `json:"presences,omitempty"`
}
GuildUpdateEvent discord.Guild
GuildUpdateEvent struct {
discord.Guild
}
GuildDeleteEvent struct {
ID discord.Snowflake `json:"id"`
ID discord.GuildID `json:"id"`
// Unavailable if false == removed
Unavailable bool `json:"unavailable"`
}
GuildBanAddEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
GuildID discord.GuildID `json:"guild_id"`
User discord.User `json:"user"`
}
GuildBanRemoveEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
GuildID discord.GuildID `json:"guild_id"`
User discord.User `json:"user"`
}
GuildEmojisUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Emojis []discord.Emoji `json:"emoji"`
GuildID discord.GuildID `json:"guild_id"`
Emojis []discord.Emoji `json:"emoji"`
}
GuildIntegrationsUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
GuildID discord.GuildID `json:"guild_id"`
}
GuildMemberAddEvent struct {
discord.Member
GuildID discord.Snowflake `json:"guild_id"`
GuildID discord.GuildID `json:"guild_id"`
}
GuildMemberRemoveEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
User discord.User `json:"user"`
GuildID discord.GuildID `json:"guild_id"`
User discord.User `json:"user"`
}
GuildMemberUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
RoleIDs []discord.Snowflake `json:"roles"`
User discord.User `json:"user"`
Nick string `json:"nick"`
GuildID discord.GuildID `json:"guild_id"`
RoleIDs []discord.RoleID `json:"roles"`
User discord.User `json:"user"`
Nick string `json:"nick"`
}
// GuildMembersChunkEvent is sent when Guild Request Members is called.
GuildMembersChunkEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Members []discord.Member `json:"members"`
GuildID discord.GuildID `json:"guild_id"`
Members []discord.Member `json:"members"`
ChunkIndex int `json:"chunk_index"`
ChunkCount int `json:"chunk_count"`
// Whatever's not found goes here
NotFound []string `json:"not_found,omitempty"`
// Only filled if requested
Presences []discord.Presence `json:"presences,omitempty"`
Nonce string `json:"nonce,omitempty"`
}
// GuildMemberListUpdate is an undocumented event. It's received when the
// client sends over GuildSubscriptions with the Channels field used.
// The State package does not handle this event.
GuildMemberListUpdate struct {
ID string `json:"id"`
GuildID discord.Snowflake `json:"guild_id"`
MemberCount uint64 `json:"member_count"`
OnlineCount uint64 `json:"online_count"`
ID string `json:"id"`
GuildID discord.GuildID `json:"guild_id"`
MemberCount uint64 `json:"member_count"`
OnlineCount uint64 `json:"online_count"`
// Groups is all the visible role sections.
Groups []GuildMemberListGroup `json:"groups"`
@ -121,7 +133,7 @@ type (
Ops []GuildMemberListOp `json:"ops"`
}
GuildMemberListGroup struct {
ID string `json:"id"` // either discord.Snowflake Role IDs or "online"
ID string `json:"id"` // either discord.RoleID, "online" or "offline"
Count uint64 `json:"count"`
}
GuildMemberListOp struct {
@ -152,16 +164,16 @@ type (
}
GuildRoleCreateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Role discord.Role `json:"role"`
GuildID discord.GuildID `json:"guild_id"`
Role discord.Role `json:"role"`
}
GuildRoleUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
Role discord.Role `json:"role"`
GuildID discord.GuildID `json:"guild_id"`
Role discord.Role `json:"role"`
}
GuildRoleDeleteEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
RoleID discord.Snowflake `json:"role_id"`
GuildID discord.GuildID `json:"guild_id"`
RoleID discord.RoleID `json:"role_id"`
}
)
@ -171,6 +183,28 @@ func (u GuildMemberUpdateEvent) Update(m *discord.Member) {
m.Nick = u.Nick
}
// https://discord.com/developers/docs/topics/gateway#invites
type (
InviteCreateEvent struct {
Code string `json:"code"`
CreatedAt discord.Timestamp `json:"created_at"`
ChannelID discord.ChannelID `json:"channel_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
// Similar to discord.Invite
Inviter *discord.User `json:"inviter,omitempty"`
Target *discord.User `json:"target_user,omitempty"`
TargetType discord.InviteUserType `json:"target_user_type,omitempty"`
discord.InviteMetadata
}
InviteDeleteEvent struct {
Code string `json:"code"`
ChannelID discord.ChannelID `json:"channel_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#messages
type (
MessageCreateEvent struct {
@ -182,48 +216,48 @@ type (
Member *discord.Member `json:"member,omitempty"`
}
MessageDeleteEvent struct {
ID discord.Snowflake `json:"id"`
ChannelID discord.Snowflake `json:"channel_id"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
ID discord.MessageID `json:"id"`
ChannelID discord.ChannelID `json:"channel_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageDeleteBulkEvent struct {
IDs []discord.Snowflake `json:"ids"`
ChannelID discord.Snowflake `json:"channel_id"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
IDs []discord.MessageID `json:"ids"`
ChannelID discord.ChannelID `json:"channel_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageReactionAddEvent struct {
UserID discord.Snowflake `json:"user_id"`
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
UserID discord.UserID `json:"user_id"`
ChannelID discord.ChannelID `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
Emoji discord.Emoji `json:"emoji,omitempty"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
}
MessageReactionRemoveEvent struct {
UserID discord.Snowflake `json:"user_id"`
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
UserID discord.UserID `json:"user_id"`
ChannelID discord.ChannelID `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
Emoji discord.Emoji `json:"emoji"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageReactionRemoveAllEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
ChannelID discord.ChannelID `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageReactionRemoveEmoji struct {
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.Snowflake `json:"message_id"`
ChannelID discord.ChannelID `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
Emoji discord.Emoji `json:"emoji"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
}
MessageAckEvent struct {
MessageID discord.Snowflake `json:"message_id"`
ChannelID discord.Snowflake `json:"channel_id"`
MessageID discord.MessageID `json:"message_id"`
ChannelID discord.ChannelID `json:"channel_id"`
}
)
@ -254,12 +288,12 @@ type (
}
TypingStartEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
ChannelID discord.ChannelID `json:"channel_id"`
UserID discord.UserID `json:"user_id"`
Timestamp discord.UnixTimestamp `json:"timestamp"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
GuildID discord.GuildID `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`
}
UserUpdateEvent struct {
@ -273,17 +307,17 @@ type (
discord.VoiceState
}
VoiceServerUpdateEvent struct {
Token string `json:"token"`
GuildID discord.Snowflake `json:"guild_id"`
Endpoint string `json:"endpoint"`
Token string `json:"token"`
GuildID discord.GuildID `json:"guild_id"`
Endpoint string `json:"endpoint"`
}
)
// https://discordapp.com/developers/docs/topics/gateway#webhooks
type (
WebhooksUpdateEvent struct {
GuildID discord.Snowflake `json:"guild_id"`
ChannelID discord.Snowflake `json:"channel_id"`
GuildID discord.GuildID `json:"guild_id"`
ChannelID discord.ChannelID `json:"channel_id"`
}
)
@ -296,16 +330,16 @@ type (
UserSettings
}
UserNoteUpdateEvent struct {
ID discord.Snowflake `json:"id"`
Note string `json:"note"`
ID discord.UserID `json:"id"`
Note string `json:"note"`
}
)
type (
RelationshipAdd struct {
Relationship
RelationshipAddEvent struct {
discord.Relationship
}
RelationshipRemove struct {
Relationship
RelationshipRemoveEvent struct {
discord.Relationship
}
)

View file

@ -9,13 +9,11 @@ var EventCreator = map[string]func() Event{
"RESUMED": func() Event { return new(ResumedEvent) },
"INVALID_SESSION": func() Event { return new(InvalidSessionEvent) },
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
"CHANNEL_UNREAD_UPDATE": func() Event {
return new(ChannelUnreadUpdateEvent)
},
"CHANNEL_CREATE": func() Event { return new(ChannelCreateEvent) },
"CHANNEL_UPDATE": func() Event { return new(ChannelUpdateEvent) },
"CHANNEL_DELETE": func() Event { return new(ChannelDeleteEvent) },
"CHANNEL_PINS_UPDATE": func() Event { return new(ChannelPinsUpdateEvent) },
"CHANNEL_UNREAD_UPDATE": func() Event { return new(ChannelUnreadUpdateEvent) },
"GUILD_CREATE": func() Event { return new(GuildCreateEvent) },
"GUILD_UPDATE": func() Event { return new(GuildUpdateEvent) },
@ -24,38 +22,31 @@ var EventCreator = map[string]func() Event{
"GUILD_BAN_ADD": func() Event { return new(GuildBanAddEvent) },
"GUILD_BAN_REMOVE": func() Event { return new(GuildBanRemoveEvent) },
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
"GUILD_INTEGRATIONS_UPDATE": func() Event {
return new(GuildIntegrationsUpdateEvent)
},
"GUILD_EMOJIS_UPDATE": func() Event { return new(GuildEmojisUpdateEvent) },
"GUILD_INTEGRATIONS_UPDATE": func() Event { return new(GuildIntegrationsUpdateEvent) },
"GUILD_MEMBER_ADD": func() Event { return new(GuildMemberAddEvent) },
"GUILD_MEMBER_REMOVE": func() Event { return new(GuildMemberRemoveEvent) },
"GUILD_MEMBER_UPDATE": func() Event { return new(GuildMemberUpdateEvent) },
"GUILD_MEMBERS_CHUNK": func() Event { return new(GuildMembersChunkEvent) },
"GUILD_MEMBER_LIST_UPDATE": func() Event {
return new(GuildMemberListUpdate)
},
"GUILD_MEMBER_LIST_UPDATE": func() Event { return new(GuildMemberListUpdate) },
"GUILD_ROLE_CREATE": func() Event { return new(GuildRoleCreateEvent) },
"GUILD_ROLE_UPDATE": func() Event { return new(GuildRoleUpdateEvent) },
"GUILD_ROLE_DELETE": func() Event { return new(GuildRoleDeleteEvent) },
"INVITE_CREATE": func() Event { return new(InviteCreateEvent) },
"INVITE_DELETE": func() Event { return new(InviteDeleteEvent) },
"MESSAGE_CREATE": func() Event { return new(MessageCreateEvent) },
"MESSAGE_UPDATE": func() Event { return new(MessageUpdateEvent) },
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
"MESSAGE_REACTION_ADD": func() Event {
return new(MessageReactionAddEvent)
},
"MESSAGE_REACTION_REMOVE": func() Event {
return new(MessageReactionRemoveEvent)
},
"MESSAGE_REACTION_REMOVE_ALL": func() Event {
return new(MessageReactionRemoveAllEvent)
},
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
"MESSAGE_ACK": func() Event { return new(MessageAckEvent) },
@ -70,16 +61,11 @@ var EventCreator = map[string]func() Event{
"WEBHOOKS_UPDATE": func() Event { return new(WebhooksUpdateEvent) },
"USER_UPDATE": func() Event {
return new(UserUpdateEvent)
},
"USER_SETTINGS_UPDATE": func() Event {
return new(UserSettingsUpdateEvent)
},
"USER_GUILD_SETTINGS_UPDATE": func() Event {
return new(UserGuildSettingsUpdateEvent)
},
"USER_NOTE_UPDATE": func() Event {
return new(UserNoteUpdateEvent)
},
"USER_UPDATE": func() Event { return new(UserUpdateEvent) },
"USER_SETTINGS_UPDATE": func() Event { return new(UserSettingsUpdateEvent) },
"USER_GUILD_SETTINGS_UPDATE": func() Event { return new(UserGuildSettingsUpdateEvent) },
"USER_NOTE_UPDATE": func() Event { return new(UserNoteUpdateEvent) },
"RELATIONSHIP_ADD": func() Event { return new(RelationshipAddEvent) },
"RELATIONSHIP_REMOVE": func() Event { return new(RelationshipRemoveEvent) },
}

View file

@ -36,16 +36,16 @@ var (
ErrWSMaxTries = errors.New("max tries reached")
)
// GatewayBotData contains the GatewayURL as well as extra metadata on how to
// BotData contains the GatewayURL as well as extra metadata on how to
// shard bots.
type GatewayBotData struct {
type BotData struct {
URL string `json:"url"`
Shards int `json:"shards,omitempty"`
StartLimit *SessionStartLimit `json:"session_start_limit"`
}
// SessionStartLimit is the information on the current session start limit. It's
// used in GatewayBotData.
// used in BotData.
type SessionStartLimit struct {
Total int `json:"total"`
Remaining int `json:"remaining"`
@ -54,7 +54,7 @@ type SessionStartLimit struct {
// URL asks Discord for a Websocket URL to the Gateway.
func URL() (string, error) {
var g GatewayBotData
var g BotData
return g.URL, httputil.NewClient().RequestJSON(
&g, "GET",
@ -64,8 +64,8 @@ func URL() (string, error) {
// BotURL fetches the Gateway URL along with extra metadata. The token
// passed in will NOT be prefixed with Bot.
func BotURL(token string) (*GatewayBotData, error) {
var g *GatewayBotData
func BotURL(token string) (*BotData, error) {
var g *BotData
return g, httputil.NewClient().RequestJSON(
&g, "GET",
@ -85,11 +85,14 @@ type Gateway struct {
// Session.
Events chan Event
// SessionID is used to store the session ID received after Ready. It is not
// thread-safe.
SessionID string
Identifier *Identifier
Sequence *Sequence
PacerLoop *wsutil.PacemakerLoop
PacerLoop wsutil.PacemakerLoop
ErrorLog func(err error) // default to log.Println
@ -98,21 +101,31 @@ type Gateway struct {
// reconnections or any type of connection interruptions.
AfterClose func(err error) // noop by default
// Mutex to hold off calls when the WS is not available. Doesn't block if
// Start() is not called or Close() is called. Also doesn't block for
// Identify or Resume.
// available sync.RWMutex
// Filled by methods, internal use
waitGroup *sync.WaitGroup
}
// NewGateway starts a new Gateway with the default stdlib JSON driver. For more
// information, refer to NewGatewayWithDriver.
// NewGatewayWithIntents creates a new Gateway with the given intents and the
// default stdlib JSON driver. Refer to NewGatewayWithDriver and AddIntents.
func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
g, err := NewGateway(token)
if err != nil {
return nil, err
}
for _, intent := range intents {
g.AddIntent(intent)
}
return g, nil
}
// NewGateway creates a new Gateway with the default stdlib JSON driver. For
// more information, refer to NewGatewayWithDriver.
func NewGateway(token string) (*Gateway, error) {
URL, err := URL()
if err != nil {
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
return nil, errors.Wrap(err, "failed to get gateway endpoint")
}
// Parameters for the gateway
@ -141,51 +154,63 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
}
}
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
// such, this function will only work before Open() is called.
func (g *Gateway) AddIntent(i Intents) {
g.Identifier.Intents |= i
}
// Close closes the underlying Websocket connection.
func (g *Gateway) Close() error {
func (g *Gateway) Close() (err error) {
wsutil.WSDebug("Trying to close.")
// Check if the WS is already closed:
if g.waitGroup == nil && g.PacerLoop.Stopped() {
if g.PacerLoop.Stopped() {
wsutil.WSDebug("Gateway is already closed.")
g.AfterClose(nil)
return nil
return err
}
// Trigger the close callback on exit.
defer func() { g.AfterClose(err) }()
// If the pacemaker is running:
if !g.PacerLoop.Stopped() {
wsutil.WSDebug("Stopping pacemaker...")
// Stop the pacemaker and the event handler
// Stop the pacemaker and the event handler.
g.PacerLoop.Stop()
wsutil.WSDebug("Stopped pacemaker.")
}
wsutil.WSDebug("Closing the websocket...")
err = g.WS.Close()
wsutil.WSDebug("Waiting for WaitGroup to be done.")
// This should work, since Pacemaker should signal its loop to stop, which
// would also exit our event loop. Both would be 2.
g.waitGroup.Wait()
// Mark g.waitGroup as empty:
g.waitGroup = nil
wsutil.WSDebug("WaitGroup is done. Closing the websocket.")
err := g.WS.Close()
g.AfterClose(err)
return err
}
// Reconnect tries to reconnect forever. It will resume the connection if
// possible. If an Invalid Session is received, it will start a fresh one.
func (g *Gateway) Reconnect() error {
return g.ReconnectContext(context.Background())
func (g *Gateway) Reconnect() {
for {
if err := g.ReconnectCtx(context.Background()); err != nil {
g.ErrorLog(err)
} else {
return
}
}
}
func (g *Gateway) ReconnectContext(ctx context.Context) error {
// ReconnectCtx attempts to reconnect until context expires. If context cannot
// expire, then the gateway will try to reconnect forever.
func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
wsutil.WSDebug("Reconnecting...")
// Guarantee the gateway is already closed. Ignore its error, as we're
@ -193,38 +218,52 @@ func (g *Gateway) ReconnectContext(ctx context.Context) error {
g.Close()
for i := 1; ; i++ {
select {
case <-ctx.Done():
return err
default:
}
wsutil.WSDebug("Trying to dial, attempt", i)
// Condition: err == ErrInvalidSession:
// If the connection is rate limited (documented behavior):
// https://discordapp.com/developers/docs/topics/gateway#rate-limiting
if err := g.OpenContext(ctx); err != nil {
g.ErrorLog(errors.Wrap(err, "Failed to open gateway"))
// make sure we don't overwrite our last error
if err = g.OpenContext(ctx); err != nil {
g.ErrorLog(err)
continue
}
wsutil.WSDebug("Started after attempt:", i)
return nil
return
}
}
// Open connects to the Websocket and authenticate it. You should usually use
// this function over Start().
func (g *Gateway) Open() error {
return g.OpenContext(context.Background())
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
return g.OpenContext(ctx)
}
// OpenContext connects to the Websocket and authenticates it. You should
// usually use this function over Start(). The given context provides
// cancellation and timeout.
func (g *Gateway) OpenContext(ctx context.Context) error {
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "Failed to reconnect")
return errors.Wrap(err, "failed to reconnect")
}
wsutil.WSDebug("Trying to start...")
// Try to resume the connection
if err := g.Start(); err != nil {
if err := g.StartCtx(ctx); err != nil {
return err
}
@ -232,14 +271,19 @@ func (g *Gateway) OpenContext(ctx context.Context) error {
return nil
}
// Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block. You wouldn't usually use this
// Start calls StartCtx with a background context. You wouldn't usually use this
// function, but Open() instead.
func (g *Gateway) Start() error {
// g.available.Lock()
// defer g.available.Unlock()
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
if err := g.start(); err != nil {
return g.StartCtx(ctx)
}
// StartCtx authenticates with the websocket, or resume from a dead Websocket
// connection. You wouldn't usually use this function, but OpenCtx() instead.
func (g *Gateway) StartCtx(ctx context.Context) error {
if err := g.start(ctx); err != nil {
wsutil.WSDebug("Start failed:", err)
// Close can be called with the mutex still acquired here, as the
@ -249,32 +293,42 @@ func (g *Gateway) Start() error {
}
return err
}
return nil
}
func (g *Gateway) start() error {
func (g *Gateway) start(ctx context.Context) error {
// This is where we'll get our events
ch := g.WS.Listen()
// Make a new WaitGroup for use in background loops:
g.waitGroup = new(sync.WaitGroup)
// Wait for an OP 10 Hello
// Create a new Hello event and wait for it.
var hello HelloEvent
if _, err := wsutil.AssertEvent(<-ch, HelloOP, &hello); err != nil {
return errors.Wrap(err, "Error at Hello")
// Wait for an OP 10 Hello.
select {
case e, ok := <-ch:
if !ok {
return errors.New("unexpected ws close while waiting for Hello")
}
if _, err := wsutil.AssertEvent(e, HelloOP, &hello); err != nil {
return errors.Wrap(err, "error at Hello")
}
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "failed to wait for Hello event")
}
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID == "" {
// SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil {
return errors.Wrap(err, "Failed to identify")
if err := g.IdentifyCtx(ctx); err != nil {
return errors.Wrap(err, "failed to identify")
}
} else {
if err := g.Resume(); err != nil {
return errors.Wrap(err, "Failed to resume")
if err := g.ResumeCtx(ctx); err != nil {
return errors.Wrap(err, "failed to resume")
}
}
@ -282,7 +336,7 @@ func (g *Gateway) start() error {
wsutil.WSDebug("Waiting for either READY or RESUMED.")
// WaitForEvent should
err := wsutil.WaitForEvent(g, ch, func(op *wsutil.OP) bool {
err := wsutil.WaitForEvent(ctx, g, ch, func(op *wsutil.OP) bool {
switch op.EventName {
case "READY":
wsutil.WSDebug("Found READY event.")
@ -295,16 +349,14 @@ func (g *Gateway) start() error {
})
if err != nil {
return errors.Wrap(err, "First error")
return errors.Wrap(err, "first error")
}
// Use the pacemaker loop.
g.PacerLoop = wsutil.NewLoop(hello.HeartbeatInterval.Duration(), ch, g)
// Start the event handler, which also handles the pacemaker death signal.
g.waitGroup.Add(1)
g.PacerLoop.RunAsync(func(err error) {
// Use the pacemaker loop.
g.PacerLoop.RunAsync(hello.HeartbeatInterval.Duration(), ch, g, func(err error) {
g.waitGroup.Done() // mark so Close() can exit.
wsutil.WSDebug("Event loop stopped with error:", err)
@ -319,7 +371,9 @@ func (g *Gateway) start() error {
return nil
}
func (g *Gateway) Send(code OPCode, v interface{}) error {
// SendCtx is a low-level function to send an OP payload to the Gateway. Most
// users shouldn't touch this, unless they know what they're doing.
func (g *Gateway) SendCtx(ctx context.Context, code OPCode, v interface{}) error {
var op = wsutil.OP{
Code: code,
}
@ -327,7 +381,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
if v != nil {
b, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "Failed to encode v")
return errors.Wrap(err, "failed to encode v")
}
op.Data = b
@ -335,9 +389,9 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
b, err := json.Marshal(op)
if err != nil {
return errors.Wrap(err, "Failed to encode payload")
return errors.Wrap(err, "failed to encode payload")
}
// WS should already be thread-safe.
return g.WS.Send(b)
return g.WS.SendCtx(ctx, b)
}

View file

@ -55,7 +55,7 @@ func (i *IdentifyData) SetShard(id, num int) {
i.Shard[0], i.Shard[1] = id, num
}
// Intents is a new Discord API feature that's documented at
// Intents for the new Discord API feature, documented at
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
type Intents uint32
@ -107,10 +107,10 @@ func NewIdentifier(data IdentifyData) *Identifier {
func (i *Identifier) Wait(ctx context.Context) error {
if err := i.IdentifyShortLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for short limit")
return errors.Wrap(err, "can't wait for short limit")
}
if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for global limit")
return errors.Wrap(err, "can't wait for global limit")
}
return nil
}

View file

@ -3,12 +3,14 @@
package gateway
import (
"context"
"log"
"os"
"strings"
"testing"
"time"
"github.com/diamondburned/arikawa/internal/heart"
"github.com/diamondburned/arikawa/utils/wsutil"
)
@ -16,6 +18,9 @@ func init() {
wsutil.WSDebug = func(v ...interface{}) {
log.Println(append([]interface{}{"Debug:"}, v...)...)
}
heart.Debug = func(v ...interface{}) {
log.Println(append([]interface{}{"Heart:"}, v...)...)
}
}
func TestInvalidToken(t *testing.T) {
@ -78,9 +83,12 @@ func TestIntegration(t *testing.T) {
// Sleep past the rate limiter before reconnecting:
time.Sleep(5 * time.Second)
// Try and reconnect forever:
gotimeout(t, func() {
if err := gateway.Reconnect(); err != nil {
// Try and reconnect for 20 seconds maximum.
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
defer cancel()
if err := gateway.ReconnectCtx(ctx); err != nil {
t.Fatal("Unexpected error while reconnecting:", err)
}
})
@ -107,13 +115,15 @@ func wait(t *testing.T, evCh chan interface{}) interface{} {
select {
case ev := <-evCh:
return ev
case <-time.After(10 * time.Second):
case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for event")
return nil
}
}
func gotimeout(t *testing.T, fn func()) {
t.Helper()
var done = make(chan struct{})
go func() {
fn()
@ -121,7 +131,7 @@ func gotimeout(t *testing.T, fn func()) {
}()
select {
case <-time.After(10 * time.Second):
case <-time.After(20 * time.Second):
t.Fatal("Timed out waiting for function.")
case <-done:
return

View file

@ -1,6 +1,7 @@
package gateway
import (
"context"
"fmt"
"math/rand"
"time"
@ -29,6 +30,11 @@ const (
GuildSubscriptionsOP OPCode = 14
)
// ErrReconnectRequest is returned by HandleOP if a ReconnectOP is given. This
// is used mostly internally to signal the heartbeat loop to reconnect, if
// needed. It is not a fatal error.
var ErrReconnectRequest = errors.New("ReconnectOP received")
func (g *Gateway) HandleOP(op *wsutil.OP) error {
switch op.Code {
case HeartbeatAckOP:
@ -36,32 +42,38 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
g.PacerLoop.Echo()
case HeartbeatOP:
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Server requesting a heartbeat.
return g.PacerLoop.Pace()
if err := g.PacerLoop.Pace(ctx); err != nil {
return wsutil.ErrBrokenConnection(errors.Wrap(err, "failed to pace"))
}
case ReconnectOP:
// Server requests to reconnect, die and retry.
wsutil.WSDebug("ReconnectOP received.")
// We must reconnect in another goroutine, as running Reconnect
// synchronously would prevent the main event loop from exiting.
go g.Reconnect()
// Gracefully exit with a nil let the event handler take the signal from
// the pacemaker.
return nil
// Exit with the ReconnectOP error to force the heartbeat event loop to
// reconnect synchronously. Not really a fatal error.
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
case InvalidSessionOP:
// Discord expects us to sleep for no reason
time.Sleep(time.Duration(rand.Intn(5)+1) * time.Second)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Invalid session, try and Identify.
if err := g.Identify(); err != nil {
if err := g.IdentifyCtx(ctx); err != nil {
// Can't identify, reconnect.
go g.Reconnect()
return wsutil.ErrBrokenConnection(ErrReconnectRequest)
}
return nil
case HelloOP:
// What is this OP doing here???
return nil
case DispatchOP:
@ -74,7 +86,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
fn, ok := EventCreator[op.EventName]
if !ok {
return fmt.Errorf(
"Unknown event %s: %s",
"unknown event %s: %s",
op.EventName, string(op.Data),
)
}
@ -84,7 +96,7 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
// Try and parse the event
if err := json.Unmarshal(op.Data, ev); err != nil {
return errors.Wrap(err, "Failed to parse event "+op.EventName)
return errors.Wrap(err, "failed to parse event "+op.EventName)
}
// If the event is a ready, we'll want its sessionID
@ -92,12 +104,12 @@ func (g *Gateway) HandleOP(op *wsutil.OP) error {
g.SessionID = ev.SessionID
}
// Throw the event into a channel, it's valid now.
// Throw the event into a channel; it's valid now.
g.Events <- ev
return nil
default:
return fmt.Errorf("Unknown OP code %d (event %s)", op.Code, op.EventName)
return fmt.Errorf("unknown OP code %d (event %s)", op.Code, op.EventName)
}
return nil

View file

@ -1,6 +1,11 @@
package gateway
import "github.com/diamondburned/arikawa/discord"
import (
"strconv"
"strings"
"github.com/diamondburned/arikawa/discord"
)
type ReadyEvent struct {
Version int `json:"version"`
@ -20,8 +25,8 @@ type ReadyEvent struct {
ReadState []ReadState `json:"read_state,omitempty"`
Presences []discord.Presence `json:"presences,omitempty"`
Relationships []Relationship `json:"relationships,omitempty"`
Notes map[discord.Snowflake]string `json:"notes,omitempty"`
Relationships []discord.Relationship `json:"relationships,omitempty"`
Notes map[discord.UserID]string `json:"notes,omitempty"`
}
type UserSettings struct {
@ -41,16 +46,16 @@ type UserSettings struct {
DeveloperMode bool `json:"developer_mode"`
DetectPlatformAccounts bool `json:"detect_platform_accounts"`
StreamNotification bool `json:"stream_notification_enabled"`
AccessibilityDetection bool `json:"allow_accessbility_detection"`
AccessibilityDetection bool `json:"allow_accessibility_detection"`
ContactSync bool `json:"contact_sync_enabled"`
NativePhoneIntegration bool `json:"native_phone_integration_enabled"`
Locale string `json:"locale"`
Theme string `json:"theme"`
GuildPositions []discord.Snowflake `json:"guild_positions"`
GuildFolders []GuildFolder `json:"guild_folders"`
RestrictedGuilds []discord.Snowflake `json:"restricted_guilds"`
GuildPositions []discord.GuildID `json:"guild_positions"`
GuildFolders []GuildFolder `json:"guild_folders"`
RestrictedGuilds []discord.GuildID `json:"restricted_guilds"`
FriendSourceFlags struct {
All bool `json:"all"`
@ -62,19 +67,19 @@ type UserSettings struct {
CustomStatus struct {
Text string `json:"text"`
ExpiresAt discord.Timestamp `json:"expires_at,omitempty"`
EmojiID discord.Snowflake `json:"emoji_id,string"`
EmojiID discord.EmojiID `json:"emoji_id,string"`
EmojiName string `json:"emoji_name"`
} `json:"custom_status"`
}
// A UserGuildSettings stores data for a users guild settings.
type UserGuildSettings struct {
GuildID discord.Snowflake `json:"guild_id"`
GuildID discord.GuildID `json:"guild_id"`
SupressEveryone bool `json:"suppress_everyone"`
SupressRoles bool `json:"suppress_roles"`
Muted bool `json:"muted"`
MobilePush bool `json:"mobile_push"`
SuppressEveryone bool `json:"suppress_everyone"`
SuppressRoles bool `json:"suppress_roles"`
Muted bool `json:"muted"`
MobilePush bool `json:"mobile_push"`
MessageNotifications UserNotification `json:"message_notifications"`
ChannelOverrides []SettingsChannelOverride `json:"channel_overrides"`
@ -91,8 +96,8 @@ const (
)
type ReadState struct {
ChannelID discord.Snowflake `json:"id"`
LastMessageID discord.Snowflake `json:"last_message_id"`
ChannelID discord.ChannelID `json:"id"`
LastMessageID discord.MessageID `json:"last_message_id"`
MentionCount int `json:"mention_count"`
}
@ -102,30 +107,42 @@ type SettingsChannelOverride struct {
Muted bool `json:"muted"`
MessageNotifications UserNotification `json:"message_notifications"`
ChannelID discord.Snowflake `json:"channel_id"`
ChannelID discord.ChannelID `json:"channel_id"`
}
// GuildFolder holds a single folder that you see in the left guild panel.
type GuildFolder struct {
Name string `json:"name"`
ID discord.Snowflake `json:"id"`
GuildIDs []discord.Snowflake `json:"guild_ids"`
Color discord.Color `json:"color"`
Name string `json:"name"`
ID GuildFolderID `json:"id"`
GuildIDs []discord.GuildID `json:"guild_ids"`
Color discord.Color `json:"color"`
}
// A Relationship between the logged in user and Relationship.User
type Relationship struct {
ID string `json:"id"`
User discord.User `json:"user"`
Type RelationshipType `json:"type"`
// GuildFolderID is possibly a snowflake. It can also be 0 (null) or a low
// number of unknown significance.
type GuildFolderID int64
func (g *GuildFolderID) UnmarshalJSON(b []byte) error {
var body = string(b)
if body == "null" {
return nil
}
body = strings.Trim(body, `"`)
u, err := strconv.ParseInt(body, 10, 64)
if err != nil {
return err
}
*g = GuildFolderID(u)
return nil
}
type RelationshipType uint8
func (g GuildFolderID) MarshalJSON() ([]byte, error) {
if g == 0 {
return []byte("null"), nil
}
const (
_ RelationshipType = iota
FriendRelationship
BlockedRelationship
IncomingFriendRequest
SentFriendRequest
)
return []byte(strconv.FormatInt(int64(g), 10)), nil
}

1
go.sum
View file

@ -16,3 +16,4 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 h1:NusfzzA6yGQ+ua51ck7E3omNUX/JuqbFSaRGqU8CcLI=
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e h1:EHBhcS0mlXEAVwNyO2dLfjToGsyY4j24pTs2ScHnX7s=

View file

@ -1,245 +0,0 @@
// Package handler handles incoming Gateway events. It reflects the function's
// first argument and caches that for use in each event.
//
// Performance
//
// Each call to the event would take 156 ns/op for roughly each handler. Scaling
// that up to 100 handlers is multiplying 156 ns by 100, which gives 15600 ns,
// or 0.0156 ms.
//
// BenchmarkReflect-8 7260909 156 ns/op
//
// Usage
//
// Handler's usage is similar to discordgo, in that AddHandler expects a
// function with only one argument. The only argument must be a pointer to one
// of the events, or an interface{} which would accept all events.
//
// AddHandler would panic if the handler is invalid.
//
// s.AddHandler(func(m *gateway.MessageCreateEvent) {
// log.Println(m.Author.Username, "said", m.Content)
// })
//
package handler
import (
"context"
"fmt"
"reflect"
"sync"
"github.com/pkg/errors"
)
type Handler struct {
// Synchronous controls whether to spawn each event handler in its own
// goroutine. Default false (meaning goroutines are spawned).
Synchronous bool
handlers map[uint64]handler
horders []uint64
hserial uint64
hmutex sync.RWMutex
}
func New() *Handler {
return &Handler{
handlers: map[uint64]handler{},
}
}
func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
if !ok {
// This shouldn't ever happen, but we're adding this just in case.
continue
}
if handler.not(evT) {
continue
}
if h.Synchronous {
handler.call(evV)
} else {
go handler.call(evV)
}
}
}
// WaitFor blocks until there's an event. It's advised to use ChanFor instead,
// as WaitFor may skip some events if it's not ran fast enough after the event
// arrived.
func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interface{} {
var result = make(chan interface{})
cancel := h.AddHandler(func(v interface{}) {
if fn(v) {
result <- v
}
})
defer cancel()
select {
case r := <-result:
return r
case <-ctx.Done():
return nil
}
}
// ChanFor returns a channel that would receive all incoming events that match
// the callback given. The cancel() function removes the handler and drops all
// hanging goroutines.
func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, cancel func()) {
result := make(chan interface{})
closer := make(chan struct{})
removeHandler := h.AddHandler(func(v interface{}) {
if fn(v) {
select {
case result <- v:
case <-closer:
}
}
})
// Only allow cancel to be called once.
var once sync.Once
cancel = func() {
once.Do(func() {
removeHandler()
close(closer)
})
}
out = result
return
}
// AddHandler adds the handler, returning a function that would remove this
// handler when called.
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler)
if err != nil {
panic(err)
}
return rm
}
// AddHandlerCheck adds the handler, but safe-guards reflect panics with a
// recoverer, returning the error.
func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
// Reflect would actually panic if anything goes wrong, so this is just in
// case.
defer func() {
if rec := recover(); rec != nil {
if recErr, ok := rec.(error); ok {
err = recErr
} else {
err = fmt.Errorf("%v", rec)
}
}
}()
return h.addHandler(handler)
}
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
// Reflect the handler
r, err := reflectFn(fn)
if err != nil {
return nil, errors.Wrap(err, "Handler reflect failed")
}
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Get the current counter value and increment the counter:
serial := h.hserial
h.hserial++
// Create a map if there's none:
if h.handlers == nil {
h.handlers = map[uint64]handler{}
}
// Use the serial for the map:
h.handlers[serial] = *r
// Append the serial into the list of keys:
h.horders = append(h.horders, serial)
return func() {
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Delete the handler from the map:
delete(h.handlers, serial)
// Delete the key from the orders slice:
for i, order := range h.horders {
if order == serial {
h.horders = append(h.horders[:i], h.horders[i+1:]...)
break
}
}
}, nil
}
type handler struct {
event reflect.Type
callback reflect.Value
isIface bool
}
func reflectFn(function interface{}) (*handler, error) {
fnV := reflect.ValueOf(function)
fnT := fnV.Type()
if fnT.Kind() != reflect.Func {
return nil, errors.New("given interface is not a function")
}
if fnT.NumIn() != 1 {
return nil, errors.New("function can only accept 1 event as argument")
}
if fnT.NumOut() > 0 {
return nil, errors.New("function can't accept returns")
}
argT := fnT.In(0)
kind := argT.Kind()
// Accept either pointer type or interface{} type
if kind != reflect.Ptr && kind != reflect.Interface {
return nil, errors.New("first argument is not pointer")
}
return &handler{
event: argT,
callback: fnV,
isIface: kind == reflect.Interface,
}, nil
}
func (h handler) not(event reflect.Type) bool {
if h.isIface {
return !event.Implements(h.event)
}
return h.event != event
}
func (h handler) call(event reflect.Value) {
h.callback.Call([]reflect.Value{event})
}

165
internal/heart/heart.go Normal file
View file

@ -0,0 +1,165 @@
// Package heart implements a general purpose pacemaker.
package heart
import (
"context"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
// Debug is the default logger that Pacemaker uses.
var Debug = func(v ...interface{}) {}
var ErrDead = errors.New("no heartbeat replied")
// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic.
type AtomicTime struct {
unixnano int64
}
func (t *AtomicTime) Get() int64 {
return atomic.LoadInt64(&t.unixnano)
}
func (t *AtomicTime) Set(time time.Time) {
atomic.StoreInt64(&t.unixnano, time.UnixNano())
}
func (t *AtomicTime) Time() time.Time {
return time.Unix(0, t.Get())
}
type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
ticker time.Ticker
Ticks <-chan time.Time
// Time in nanoseconds, guarded by atomic read/writes.
SentBeat AtomicTime
EchoBeat AtomicTime
// Any callback that returns an error will stop the pacer.
Pacer func(context.Context) error
}
func NewPacemaker(heartrate time.Duration, pacer func(context.Context) error) Pacemaker {
p := Pacemaker{
Heartrate: heartrate,
Pacer: pacer,
ticker: *time.NewTicker(heartrate),
}
p.Ticks = p.ticker.C
// Reset states to its old position.
now := time.Now()
p.EchoBeat.Set(now)
p.SentBeat.Set(now)
return p
}
func (p *Pacemaker) Echo() {
// Swap our received heartbeats
p.EchoBeat.Set(time.Now())
}
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
var (
echo = p.EchoBeat.Get()
sent = p.SentBeat.Get()
)
if echo == 0 || sent == 0 {
return false
}
return sent-echo > int64(p.Heartrate)*2
}
// Stop stops the pacemaker, or it does nothing if the pacemaker is not started.
func (p *Pacemaker) Stop() {
p.ticker.Stop()
}
// pace sends a heartbeat with the appropriate timeout for the context.
func (p *Pacemaker) Pace() error {
ctx, cancel := context.WithTimeout(context.Background(), p.Heartrate)
defer cancel()
return p.PaceCtx(ctx)
}
func (p *Pacemaker) PaceCtx(ctx context.Context) error {
if err := p.Pacer(ctx); err != nil {
return err
}
p.SentBeat.Set(time.Now())
if p.Dead() {
return ErrDead
}
return nil
}
// func (p *Pacemaker) start() error {
// // Reset states to its old position.
// p.EchoBeat.Set(time.Time{})
// p.SentBeat.Set(time.Time{})
// // Create a new ticker.
// tick := time.NewTicker(p.Heartrate)
// defer tick.Stop()
// // Echo at least once
// p.Echo()
// for {
// if err := p.pace(); err != nil {
// return errors.Wrap(err, "failed to pace")
// }
// // Paced, save:
// p.SentBeat.Set(time.Now())
// if p.Dead() {
// return ErrDead
// }
// select {
// case <-p.stop:
// return nil
// case <-tick.C:
// }
// }
// }
// // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
// func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
// p.death = make(chan error)
// p.stop = make(chan struct{})
// p.once = sync.Once{}
// if wg != nil {
// wg.Add(1)
// }
// go func() {
// p.death <- p.start()
// // Debug.
// Debug("Pacemaker returned.")
// // Mark the pacemaker loop as done.
// if wg != nil {
// wg.Done()
// }
// }()
// return p.death
// }

View file

@ -0,0 +1,51 @@
package moreatomic
import (
"sync"
"github.com/diamondburned/arikawa/discord"
)
type GuildIDSet struct {
set map[discord.GuildID]struct{}
mut sync.Mutex
}
// NewGuildIDSet creates a new GuildIDSet.
func NewGuildIDSet() *GuildIDSet {
return &GuildIDSet{
set: make(map[discord.GuildID]struct{}),
}
}
// Add adds the passed discord.GuildID to the set.
func (s *GuildIDSet) Add(flake discord.GuildID) {
s.mut.Lock()
s.set[flake] = struct{}{}
s.mut.Unlock()
}
// Contains checks whether the passed discord.GuildID is present in the set.
func (s *GuildIDSet) Contains(flake discord.GuildID) (ok bool) {
s.mut.Lock()
defer s.mut.Unlock()
_, ok = s.set[flake]
return
}
// Delete deletes the passed discord.GuildID from the set and returns true if
// the element is present. If not, Delete is a no-op and returns false.
func (s *GuildIDSet) Delete(flake discord.GuildID) bool {
s.mut.Lock()
defer s.mut.Unlock()
if _, ok := s.set[flake]; ok {
delete(s.set, flake)
return true
}
return false
}

View file

@ -42,6 +42,16 @@ func (m *CtxMutex) Lock(ctx context.Context) error {
}
}
// TryUnlock returns true if the mutex has been unlocked.
func (m *CtxMutex) TryUnlock() bool {
select {
case <-m.mut:
return true
default:
return false
}
}
func (m *CtxMutex) Unlock() {
select {
case <-m.mut:

View file

@ -0,0 +1,51 @@
package moreatomic
import (
"sync"
"github.com/diamondburned/arikawa/discord"
)
type SnowflakeSet struct {
set map[discord.Snowflake]struct{}
mut sync.Mutex
}
// NewSnowflakeSet creates a new SnowflakeSet.
func NewSnowflakeSet() *SnowflakeSet {
return &SnowflakeSet{
set: make(map[discord.Snowflake]struct{}),
}
}
// Add adds the passed discord.Snowflake to the set.
func (s *SnowflakeSet) Add(flake discord.Snowflake) {
s.mut.Lock()
s.set[flake] = struct{}{}
s.mut.Unlock()
}
// Contains checks whether the passed discord.Snowflake is present in the set.
func (s *SnowflakeSet) Contains(flake discord.Snowflake) (ok bool) {
s.mut.Lock()
defer s.mut.Unlock()
_, ok = s.set[flake]
return
}
// Delete deletes the passed discord.Snowflake from the set and returns true if
// the element is present. If not, Delete is a no-op and returns false.
func (s *SnowflakeSet) Delete(flake discord.Snowflake) bool {
s.mut.Lock()
defer s.mut.Unlock()
if _, ok := s.set[flake]; ok {
delete(s.set, flake)
return true
}
return false
}

View file

@ -0,0 +1 @@
package mulipartutil

View file

@ -61,14 +61,14 @@ func (i *Inflator) Flush() ([]byte, error) {
if i.zlib == nil {
r, err := zlibStreamer(&i.wbuf)
if err != nil {
return nil, errors.Wrap(err, "Failed to make a FLATE reader")
return nil, errors.Wrap(err, "failed to make a FLATE reader")
}
// safe assertion
i.zlib = r
// } else {
// // Reset the FLATE reader for future use:
// if err := i.zlib.Reset(&i.wbuf, nil); err != nil {
// return nil, errors.Wrap(err, "Failed to reset zlib reader")
// return nil, errors.Wrap(err, "failed to reset zlib reader")
// }
}
@ -79,12 +79,12 @@ func (i *Inflator) Flush() ([]byte, error) {
// to verify checksum. Discord doesn't send this.
if err != nil {
// Unexpected error, try and close.
return nil, errors.Wrap(err, "Failed to read from FLATE reader")
return nil, errors.Wrap(err, "failed to read from FLATE reader")
}
// if err := i.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
// // Try and close anyway.
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
// return nil, errors.Wrap(err, "failed to read from zlib reader")
// }
// Copy the bytes.
@ -105,7 +105,7 @@ func (i *Inflator) Flush() ([]byte, error) {
// if d.zlib == nil {
// r, err := zlib.NewReader(&d.wbuf)
// if err != nil {
// return nil, errors.Wrap(err, "Failed to make a zlib reader")
// return nil, errors.Wrap(err, "failed to make a zlib reader")
// }
// // safe assertion
// d.zlib = r
@ -125,12 +125,12 @@ func (i *Inflator) Flush() ([]byte, error) {
// // to verify checksum. Discord doesn't send this.
// // if err != nil && err != io.ErrUnexpectedEOF {
// // // Unexpected error, try and close.
// // return nil, errors.Wrap(err, "Failed to read from zlib reader")
// // return nil, errors.Wrap(err, "failed to read from zlib reader")
// // }
// if err := d.zlib.Close(); err != nil && err != io.ErrUnexpectedEOF {
// // Try and close anyway.
// return nil, errors.Wrap(err, "Failed to read from zlib reader")
// return nil, errors.Wrap(err, "failed to read from zlib reader")
// }
// // Copy the bytes.

View file

@ -4,14 +4,19 @@
package session
import (
"sync"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/utils/handler"
)
var ErrMFA = errors.New("account has 2FA enabled")
// Closed is an event that's sent to Session's command handler. This works by
// using (*Gateway).AfterError. If the user sets this callback, no Closed events
// using (*Gateway).AfterClose. If the user sets this callback, no Closed events
// would be sent.
//
// Usage
@ -22,8 +27,6 @@ type Closed struct {
Error error
}
var ErrMFA = errors.New("Account has 2FA enabled")
// Session manages both the API and Gateway. As such, Session inherits all of
// API's methods, as well has the Handler used for Gateway.
type Session struct {
@ -38,22 +41,28 @@ type Session struct {
Ticket string
hstop chan struct{}
wstop sync.Once
}
func New(token string) (*Session, error) {
// Initialize the session and the API interface
s := &Session{}
s.Handler = handler.New()
s.Client = api.NewClient(token)
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
g, err := gateway.NewGatewayWithIntents(token, intents...)
if err != nil {
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
return NewWithGateway(g), nil
}
// New creates a new session from a given token. Most bots should be using
// NewWithIntents instead.
func New(token string) (*Session, error) {
// Create a gateway
g, err := gateway.NewGateway(token)
if err != nil {
return nil, errors.Wrap(err, "Failed to connect to Gateway")
return nil, errors.Wrap(err, "failed to connect to Gateway")
}
s.Gateway = g
return s, nil
return NewWithGateway(g), nil
}
// Login tries to log in as a normal user account; MFA is optional.
@ -64,7 +73,7 @@ func Login(email, password, mfa string) (*Session, error) {
// Try to login without TOTP
l, err := client.Login(email, password)
if err != nil {
return nil, errors.Wrap(err, "Failed to login")
return nil, errors.Wrap(err, "failed to login")
}
if l.Token != "" && !l.MFA {
@ -80,7 +89,7 @@ func Login(email, password, mfa string) (*Session, error) {
// Retry logging in with a 2FA token
l, err = client.TOTP(mfa, l.Ticket)
if err != nil {
return nil, errors.Wrap(err, "Failed to login with 2FA")
return nil, errors.Wrap(err, "failed to login with 2FA")
}
return New(l.Token)
@ -97,9 +106,9 @@ func NewWithGateway(gw *gateway.Gateway) *Session {
func (s *Session) Open() error {
// Start the handler beforehand so no events are missed.
stop := make(chan struct{})
s.hstop = stop
go s.startHandler(stop)
s.hstop = make(chan struct{})
s.wstop = sync.Once{}
go s.startHandler()
// Set the AfterClose's handler.
s.Gateway.AfterClose = func(err error) {
@ -109,33 +118,26 @@ func (s *Session) Open() error {
}
if err := s.Gateway.Open(); err != nil {
return errors.Wrap(err, "Failed to start gateway")
return errors.Wrap(err, "failed to start gateway")
}
return nil
}
func (s *Session) startHandler(stop <-chan struct{}) {
func (s *Session) startHandler() {
for {
select {
case <-stop:
case <-s.hstop:
return
case ev := <-s.Gateway.Events:
s.Handler.Call(ev)
s.Call(ev)
}
}
}
func (s *Session) Close() error {
// Stop the event handler
s.close()
s.wstop.Do(func() { s.hstop <- struct{}{} })
// Close the websocket
return s.Gateway.Close()
}
func (s *Session) close() {
if s.hstop != nil {
close(s.hstop)
}
}

58
state/event_dispatcher.go Normal file
View file

@ -0,0 +1,58 @@
package state
import (
"github.com/diamondburned/arikawa/gateway"
)
func (s *State) handleReady(ev *gateway.ReadyEvent) {
for _, g := range ev.Guilds {
// store this so we know when we need to dispatch a belated
// GuildReadyEvent
if g.Unavailable {
s.unreadyGuilds.Add(g.ID)
} else {
s.Handler.Call(&GuildReadyEvent{
GuildCreateEvent: &g,
})
}
}
}
func (s *State) handleGuildCreate(ev *gateway.GuildCreateEvent) {
// this guild was unavailable, but has come back online
if s.unavailableGuilds.Delete(ev.ID) {
s.Handler.Call(&GuildAvailableEvent{
GuildCreateEvent: ev,
})
// the guild was already unavailable when connecting to the gateway
// we can dispatch a belated GuildReadyEvent
} else if s.unreadyGuilds.Delete(ev.ID) {
s.Handler.Call(&GuildReadyEvent{
GuildCreateEvent: ev,
})
} else { // we don't know this guild, hence we just joined it
s.Handler.Call(&GuildJoinEvent{
GuildCreateEvent: ev,
})
}
}
func (s *State) handleGuildDelete(ev *gateway.GuildDeleteEvent) {
// store this so we can later dispatch a GuildAvailableEvent, once the
// guild becomes available again.
if ev.Unavailable {
s.unavailableGuilds.Add(ev.ID)
s.Handler.Call(&GuildUnavailableEvent{
GuildDeleteEvent: ev,
})
} else {
// it might have been unavailable before we left
s.unavailableGuilds.Delete(ev.ID)
s.Handler.Call(&GuildLeaveEvent{
GuildDeleteEvent: ev,
})
}
}

42
state/events.go Normal file
View file

@ -0,0 +1,42 @@
package state
import "github.com/diamondburned/arikawa/gateway"
// events that originated from GuildCreate:
type (
// GuildReady gets fired for every guild the bot/user is in, as found in
// the Ready event.
//
// Guilds that are unavailable when connecting, will not trigger a
// GuildReadyEvent, until they become available again.
GuildReadyEvent struct {
*gateway.GuildCreateEvent
}
// GuildAvailableEvent gets fired when a guild becomes available again,
// after being previously declared unavailable through a
// GuildUnavailableEvent. This event will not be fired for guilds that
// were already unavailable when connecting to the gateway.
GuildAvailableEvent struct {
*gateway.GuildCreateEvent
}
// GuildJoinEvent gets fired if the bot/user joins a guild.
GuildJoinEvent struct {
*gateway.GuildCreateEvent
}
)
// events that originated from GuildDelete:
type (
// GuildLeaveEvent gets fired if the bot/user left a guild, was removed
// or the owner deleted the guild.
GuildLeaveEvent struct {
*gateway.GuildDeleteEvent
}
// GuildUnavailableEvent gets fired if a guild becomes unavailable.
GuildUnavailableEvent struct {
*gateway.GuildDeleteEvent
}
)

View file

@ -8,16 +8,54 @@ import (
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/handler"
"github.com/diamondburned/arikawa/internal/moreatomic"
"github.com/diamondburned/arikawa/session"
"github.com/diamondburned/arikawa/utils/handler"
"github.com/pkg/errors"
)
var (
MaxFetchMembers uint = 1000
MaxFetchGuilds uint = 100
MaxFetchGuilds uint = 10
)
// State is the cache to store events coming from Discord as well as data from
// API calls.
//
// Store
//
// The state basically provides abstractions on top of the API and the state
// storage (Store). The state storage is effectively a set of interfaces which
// allow arbitrary backends to be implemented.
//
// The default storage backend is a typical in-memory structure consisting of
// maps and slices. Custom backend implementations could embed this storage
// backend as an in-memory fallback. A good example of this would be embedding
// the default store for messages only, while handling everything else in Redis.
//
// The package also provides a no-op store (NoopStore) that implementations
// could embed. This no-op store will always return an error, which makes the
// state fetch information from the API. The setters are all no-ops, so the
// fetched data won't be updated.
//
// Handler
//
// The state uses its own handler over session's to make all handlers run after
// the state updates itself. A PreHandler is exposed in any case the user needs
// the handlers to run before the state updates itself. Refer to that field's
// documentation.
//
// The state also provides extra events and overrides to make up for Discord's
// inconsistencies in data. The following are known instances of such.
//
// The Guild Create event is split up to make the state's Guild Available, Guild
// Ready and Guild Join events. Refer to these events' documentations for more
// information.
//
// The Message Create and Message Update events with the Member field provided
// will have the User field copied from Author. This is because the User field
// will be empty, while the Member structure expects it to be there.
type State struct {
*session.Session
Store
@ -39,21 +77,40 @@ type State struct {
// Command handler with inherited methods. Ran after PreHandler. You should
// most of the time use this instead of Session's, to avoid race conditions
// with the State
// with the State.
*handler.Handler
unhooker func()
// List of channels with few messages, so it doesn't bother hitting the API
// again.
fewMessages map[discord.Snowflake]struct{}
fewMessages map[discord.ChannelID]struct{}
fewMutex *sync.Mutex
// unavailableGuilds is a set of discord.GuildIDs of guilds that became
// unavailable when already connected to the gateway, i.e. sent in a
// GuildUnavailableEvent.
unavailableGuilds *moreatomic.GuildIDSet
// unreadyGuilds is a set of discord.GuildIDs of guilds that were
// unavailable when connecting to the gateway, i.e. they had Unavailable
// set to true during Ready.
unreadyGuilds *moreatomic.GuildIDSet
}
// New creates a new state.
func New(token string) (*State, error) {
return NewWithStore(token, NewDefaultStore(nil))
}
// NewWithIntents creates a new state with the given gateway intents. For more
// information, refer to gateway.Intents.
func NewWithIntents(token string, intents ...gateway.Intents) (*State, error) {
s, err := session.NewWithIntents(token, intents...)
if err != nil {
return nil, err
}
return NewFromSession(s, NewDefaultStore(nil))
}
func NewWithStore(token string, store Store) (*State, error) {
s, err := session.New(token)
if err != nil {
@ -63,17 +120,21 @@ func NewWithStore(token string, store Store) (*State, error) {
return NewFromSession(s, store)
}
// NewFromSession never returns an error. This API is kept for backwards
// compatibility.
func NewFromSession(s *session.Session, store Store) (*State, error) {
state := &State{
Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.Snowflake]struct{}{},
fewMutex: new(sync.Mutex),
Session: s,
Store: store,
Handler: handler.New(),
StateLog: func(err error) {},
fewMessages: map[discord.ChannelID]struct{}{},
fewMutex: new(sync.Mutex),
unavailableGuilds: moreatomic.NewGuildIDSet(),
unreadyGuilds: moreatomic.NewGuildIDSet(),
}
return state, state.hookSession()
state.hookSession()
return state, nil
}
// WithContext returns a shallow copy of State with the context replaced in the
@ -89,7 +150,7 @@ func (s *State) WithContext(ctx context.Context) *State {
//// Helper methods
func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {
if !message.GuildID.Valid() {
if !message.GuildID.IsValid() {
return message.Author.Username
}
@ -108,7 +169,7 @@ func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {
return n
}
func (s *State) MemberDisplayName(guildID, userID discord.Snowflake) (string, error) {
func (s *State) MemberDisplayName(guildID discord.GuildID, userID discord.UserID) (string, error) {
member, err := s.Member(guildID, userID)
if err != nil {
return "", err
@ -121,52 +182,92 @@ func (s *State) MemberDisplayName(guildID, userID discord.Snowflake) (string, er
return member.Nick, nil
}
func (s *State) AuthorColor(message *gateway.MessageCreateEvent) discord.Color {
if !message.GuildID.Valid() {
return discord.DefaultMemberColor
func (s *State) AuthorColor(message *gateway.MessageCreateEvent) (discord.Color, error) {
if !message.GuildID.IsValid() { // this is a dm
return discord.DefaultMemberColor, nil
}
if message.Member != nil {
guild, err := s.Guild(message.GuildID)
if err != nil {
return discord.DefaultMemberColor
return 0, err
}
return discord.MemberColor(*guild, *message.Member)
return discord.MemberColor(*guild, *message.Member), nil
}
return s.MemberColor(message.GuildID, message.Author.ID)
}
func (s *State) MemberColor(guildID, userID discord.Snowflake) discord.Color {
member, err := s.Member(guildID, userID)
if err != nil {
return discord.DefaultMemberColor
func (s *State) MemberColor(guildID discord.GuildID, userID discord.UserID) (discord.Color, error) {
var wg sync.WaitGroup
g, gerr := s.Store.Guild(guildID)
m, merr := s.Store.Member(guildID, userID)
switch {
case gerr != nil && merr != nil:
wg.Add(1)
go func() {
g, gerr = s.fetchGuild(guildID)
wg.Done()
}()
m, merr = s.fetchMember(guildID, userID)
case gerr != nil:
g, gerr = s.fetchGuild(guildID)
case merr != nil:
m, merr = s.fetchMember(guildID, userID)
}
guild, err := s.Guild(guildID)
if err != nil {
return discord.DefaultMemberColor
wg.Wait()
if gerr != nil {
return 0, errors.Wrap(merr, "failed to get guild")
}
if merr != nil {
return 0, errors.Wrap(merr, "failed to get member")
}
return discord.MemberColor(*guild, *member)
return discord.MemberColor(*g, *m), nil
}
////
func (s *State) Permissions(channelID, userID discord.Snowflake) (discord.Permissions, error) {
func (s *State) Permissions(
channelID discord.ChannelID, userID discord.UserID) (discord.Permissions, error) {
ch, err := s.Channel(channelID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get channel")
return 0, errors.Wrap(err, "failed to get channel")
}
g, err := s.Guild(ch.GuildID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get guild")
var wg sync.WaitGroup
g, gerr := s.Store.Guild(ch.GuildID)
m, merr := s.Store.Member(ch.GuildID, userID)
switch {
case gerr != nil && merr != nil:
wg.Add(1)
go func() {
g, gerr = s.fetchGuild(ch.GuildID)
wg.Done()
}()
m, merr = s.fetchMember(ch.GuildID, userID)
case gerr != nil:
g, gerr = s.fetchGuild(ch.GuildID)
case merr != nil:
m, merr = s.fetchMember(ch.GuildID, userID)
}
m, err := s.Member(ch.GuildID, userID)
if err != nil {
return 0, errors.Wrap(err, "Failed to get member")
wg.Wait()
if gerr != nil {
return 0, errors.Wrap(merr, "failed to get guild")
}
if merr != nil {
return 0, errors.Wrap(merr, "failed to get member")
}
return discord.CalcOverwrites(*g, *ch, *m), nil
@ -185,12 +286,12 @@ func (s *State) Me() (*discord.User, error) {
return nil, err
}
return u, s.Store.MyselfSet(u)
return u, s.Store.MyselfSet(*u)
}
////
func (s *State) Channel(id discord.Snowflake) (*discord.Channel, error) {
func (s *State) Channel(id discord.ChannelID) (*discord.Channel, error) {
c, err := s.Store.Channel(id)
if err == nil {
return c, nil
@ -201,10 +302,10 @@ func (s *State) Channel(id discord.Snowflake) (*discord.Channel, error) {
return nil, err
}
return c, s.Store.ChannelSet(c)
return c, s.Store.ChannelSet(*c)
}
func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
func (s *State) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
c, err := s.Store.Channels(guildID)
if err == nil {
return c, nil
@ -218,7 +319,7 @@ func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
if err := s.Store.ChannelSet(ch); err != nil {
return nil, err
}
}
@ -226,7 +327,7 @@ func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
return c, nil
}
func (s *State) CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error) {
func (s *State) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
c, err := s.Store.CreatePrivateChannel(recipient)
if err == nil {
return c, nil
@ -237,7 +338,7 @@ func (s *State) CreatePrivateChannel(recipient discord.Snowflake) (*discord.Chan
return nil, err
}
return c, s.Store.ChannelSet(c)
return c, s.Store.ChannelSet(*c)
}
func (s *State) PrivateChannels() ([]discord.Channel, error) {
@ -254,7 +355,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
for _, ch := range c {
ch := ch
if err := s.Store.ChannelSet(&ch); err != nil {
if err := s.Store.ChannelSet(ch); err != nil {
return nil, err
}
}
@ -265,7 +366,7 @@ func (s *State) PrivateChannels() ([]discord.Channel, error) {
////
func (s *State) Emoji(
guildID, emojiID discord.Snowflake) (*discord.Emoji, error) {
guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
e, err := s.Store.Emoji(guildID, emojiID)
if err == nil {
@ -290,7 +391,7 @@ func (s *State) Emoji(
return nil, ErrStoreNotFound
}
func (s *State) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
func (s *State) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
e, err := s.Store.Emojis(guildID)
if err == nil {
return e, nil
@ -306,18 +407,13 @@ func (s *State) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
////
func (s *State) Guild(id discord.Snowflake) (*discord.Guild, error) {
func (s *State) Guild(id discord.GuildID) (*discord.Guild, error) {
c, err := s.Store.Guild(id)
if err == nil {
return c, nil
}
c, err = s.Session.Guild(id)
if err != nil {
return nil, err
}
return c, s.Store.GuildSet(c)
return s.fetchGuild(id)
}
// Guilds will only fill a maximum of 100 guilds from the API.
@ -335,7 +431,7 @@ func (s *State) Guilds() ([]discord.Guild, error) {
for _, ch := range c {
ch := ch
if err := s.Store.GuildSet(&ch); err != nil {
if err := s.Store.GuildSet(ch); err != nil {
return nil, err
}
}
@ -345,23 +441,16 @@ func (s *State) Guilds() ([]discord.Guild, error) {
////
func (s *State) Member(
guildID, userID discord.Snowflake) (*discord.Member, error) {
func (s *State) Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
m, err := s.Store.Member(guildID, userID)
if err == nil {
return m, nil
}
m, err = s.Session.Member(guildID, userID)
if err != nil {
return nil, err
}
return m, s.Store.MemberSet(guildID, m)
return s.fetchMember(guildID, userID)
}
func (s *State) Members(guildID discord.Snowflake) ([]discord.Member, error) {
func (s *State) Members(guildID discord.GuildID) ([]discord.Member, error) {
ms, err := s.Store.Members(guildID)
if err == nil {
return ms, nil
@ -373,13 +462,13 @@ func (s *State) Members(guildID discord.Snowflake) ([]discord.Member, error) {
}
for _, m := range ms {
if err := s.Store.MemberSet(guildID, &m); err != nil {
if err := s.Store.MemberSet(guildID, m); err != nil {
return nil, err
}
}
return ms, s.Gateway.RequestGuildMembers(gateway.RequestGuildMembersData{
GuildID: []discord.Snowflake{guildID},
GuildID: []discord.GuildID{guildID},
Presences: true,
})
}
@ -387,31 +476,48 @@ func (s *State) Members(guildID discord.Snowflake) ([]discord.Member, error) {
////
func (s *State) Message(
channelID, messageID discord.Snowflake) (*discord.Message, error) {
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
m, err := s.Store.Message(channelID, messageID)
if err == nil {
return m, nil
}
var wg sync.WaitGroup
c, cerr := s.Store.Channel(channelID)
if cerr != nil {
wg.Add(1)
go func() {
c, cerr = s.Session.Channel(channelID)
if cerr == nil {
cerr = s.Store.ChannelSet(*c)
}
wg.Done()
}()
}
m, err = s.Session.Message(channelID, messageID)
if err != nil {
return nil, err
return nil, errors.Wrap(err, "unable to fetch message")
}
// Fill the GuildID, because Discord doesn't do it for us.
c, err := s.Channel(channelID)
if err == nil {
// If it's 0, it's 0 anyway. We don't need a check here.
m.GuildID = c.GuildID
wg.Wait()
if cerr != nil {
return nil, errors.Wrap(cerr, "unable to fetch channel")
}
return m, s.Store.MessageSet(m)
m.ChannelID = c.ID
m.GuildID = c.GuildID
return m, s.Store.MessageSet(*m)
}
// Messages fetches maximum 100 messages from the API, if it has to. There is no
// limit if it's from the State storage.
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
func (s *State) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
// TODO: Think of a design that doesn't rely on MaxMessages().
var maxMsgs = s.MaxMessages()
@ -440,7 +546,7 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
// New messages fetched weirdly does not have GuildID filled. We'll try and
// get it for consistency with incoming message creates.
var guildID discord.Snowflake
var guildID discord.GuildID
// A bit too convoluted, but whatever.
c, err := s.Channel(channelID)
@ -449,11 +555,13 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
guildID = c.GuildID
}
for i := range ms {
// Iterate in reverse, since the store is expected to prepend the latest
// messages.
for i := len(ms) - 1; i >= 0; i-- {
// Set the guild ID, fine if it's 0 (it's already 0 anyway).
ms[i].GuildID = guildID
if err := s.Store.MessageSet(&ms[i]); err != nil {
if err := s.Store.MessageSet(ms[i]); err != nil {
return nil, err
}
}
@ -476,14 +584,16 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
// Presence checks the state for user presences. If no guildID is given, it will
// look for the presence in all guilds.
func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) {
func (s *State) Presence(
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
p, err := s.Store.Presence(guildID, userID)
if err == nil {
return p, nil
}
// If there's no guild ID, look in all guilds
if !guildID.Valid() {
if !guildID.IsValid() {
g, err := s.Guilds()
if err != nil {
return nil, err
@ -501,8 +611,7 @@ func (s *State) Presence(guildID, userID discord.Snowflake) (*discord.Presence,
////
func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
func (s *State) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
r, err := s.Store.Role(guildID, roleID)
if err == nil {
return r, nil
@ -522,15 +631,19 @@ func (s *State) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
role = &r
}
if err := s.RoleSet(guildID, &r); err != nil {
if err := s.RoleSet(guildID, r); err != nil {
return role, err
}
}
if role == nil {
return nil, ErrStoreNotFound
}
return role, nil
}
func (s *State) Roles(guildID discord.Snowflake) ([]discord.Role, error) {
func (s *State) Roles(guildID discord.GuildID) ([]discord.Role, error) {
rs, err := s.Store.Roles(guildID)
if err == nil {
return rs, nil
@ -544,10 +657,30 @@ func (s *State) Roles(guildID discord.Snowflake) ([]discord.Role, error) {
for _, r := range rs {
r := r
if err := s.RoleSet(guildID, &r); err != nil {
if err := s.RoleSet(guildID, r); err != nil {
return rs, err
}
}
return rs, nil
}
func (s *State) fetchGuild(id discord.GuildID) (g *discord.Guild, err error) {
g, err = s.Session.Guild(id)
if err == nil {
err = s.Store.GuildSet(*g)
}
return
}
func (s *State) fetchMember(
guildID discord.GuildID, userID discord.UserID) (m *discord.Member, err error) {
m, err = s.Session.Member(guildID, userID)
if err == nil {
err = s.Store.MemberSet(guildID, *m)
}
return
}

View file

@ -1,71 +1,102 @@
package state
import (
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
)
func (s *State) hookSession() error {
s.unhooker = s.Session.AddHandler(func(iface interface{}) {
func (s *State) hookSession() {
s.Session.AddHandler(func(event interface{}) {
// Call the pre-handler before the state handler.
if s.PreHandler != nil {
s.PreHandler.Call(iface)
s.PreHandler.Call(event)
}
s.onEvent(iface)
s.Handler.Call(iface)
})
return nil
// Run the state handler.
s.onEvent(event)
switch event := event.(type) {
case *gateway.ReadyEvent:
s.Handler.Call(event)
s.handleReady(event)
case *gateway.GuildCreateEvent:
s.Handler.Call(event)
s.handleGuildCreate(event)
case *gateway.GuildDeleteEvent:
s.Handler.Call(event)
s.handleGuildDelete(event)
// https://github.com/discord/discord-api-docs/commit/01665c4
case *gateway.MessageCreateEvent:
if event.Member != nil {
event.Member.User = event.Author
}
s.Handler.Call(event)
case *gateway.MessageUpdateEvent:
if event.Member != nil {
event.Member.User = event.Author
}
s.Handler.Call(event)
default:
s.Handler.Call(event)
}
})
}
func (s *State) onEvent(iface interface{}) {
switch ev := iface.(type) {
case *gateway.ReadyEvent:
// Reset the store before proceeding.
if resetter, ok := s.Store.(StoreResetter); ok {
if err := resetter.Reset(); err != nil {
s.stateErr(err, "Failed to reset state on READY")
}
}
// Set Ready to the state
s.Ready = *ev
// Handle presences
for _, p := range ev.Presences {
p := p
if err := s.Store.PresenceSet(0, &p); err != nil {
s.stateErr(err, "Failed to set global presence")
if err := s.Store.PresenceSet(0, p); err != nil {
s.stateErr(err, "failed to set global presence")
}
}
// Handle guilds
for i := range ev.Guilds {
s.batchLog(handleGuildCreate(s.Store, &ev.Guilds[i])...)
s.batchLog(storeGuildCreate(s.Store, &ev.Guilds[i])...)
}
// Handle private channels
for i := range ev.PrivateChannels {
if err := s.Store.ChannelSet(&ev.PrivateChannels[i]); err != nil {
s.stateErr(err, "Failed to set channel in state")
for _, ch := range ev.PrivateChannels {
if err := s.Store.ChannelSet(ch); err != nil {
s.stateErr(err, "failed to set channel in state")
}
}
// Handle user
if err := s.Store.MyselfSet(&ev.User); err != nil {
s.stateErr(err, "Failed to set self in state")
if err := s.Store.MyselfSet(ev.User); err != nil {
s.stateErr(err, "failed to set self in state")
}
case *gateway.GuildCreateEvent:
s.batchLog(handleGuildCreate(s.Store, ev)...)
case *gateway.GuildUpdateEvent:
if err := s.Store.GuildSet((*discord.Guild)(ev)); err != nil {
s.stateErr(err, "Failed to update guild in state")
if err := s.Store.GuildSet(ev.Guild); err != nil {
s.stateErr(err, "failed to update guild in state")
}
case *gateway.GuildDeleteEvent:
if err := s.Store.GuildRemove(ev.ID); err != nil {
s.stateErr(err, "Failed to delete guild in state")
if err := s.Store.GuildRemove(ev.ID); err != nil && !ev.Unavailable {
s.stateErr(err, "failed to delete guild in state")
}
case *gateway.GuildMemberAddEvent:
if err := s.Store.MemberSet(ev.GuildID, &ev.Member); err != nil {
s.stateErr(err, "Failed to add a member in state")
if err := s.Store.MemberSet(ev.GuildID, ev.Member); err != nil {
s.stateErr(err, "failed to add a member in state")
}
case *gateway.GuildMemberUpdateEvent:
@ -78,89 +109,85 @@ func (s *State) onEvent(iface interface{}) {
// Update available fields from ev into m
ev.Update(m)
if err := s.Store.MemberSet(ev.GuildID, m); err != nil {
s.stateErr(err, "Failed to update a member in state")
if err := s.Store.MemberSet(ev.GuildID, *m); err != nil {
s.stateErr(err, "failed to update a member in state")
}
case *gateway.GuildMemberRemoveEvent:
if err := s.Store.MemberRemove(ev.GuildID, ev.User.ID); err != nil {
s.stateErr(err, "Failed to remove a member in state")
s.stateErr(err, "failed to remove a member in state")
}
case *gateway.GuildMembersChunkEvent:
for _, m := range ev.Members {
m := m
if err := s.Store.MemberSet(ev.GuildID, &m); err != nil {
s.stateErr(err, "Failed to add a member from chunk in state")
if err := s.Store.MemberSet(ev.GuildID, m); err != nil {
s.stateErr(err, "failed to add a member from chunk in state")
}
}
for _, p := range ev.Presences {
p := p
if err := s.Store.PresenceSet(ev.GuildID, &p); err != nil {
s.stateErr(err, "Failed to add a presence from chunk in state")
if err := s.Store.PresenceSet(ev.GuildID, p); err != nil {
s.stateErr(err, "failed to add a presence from chunk in state")
}
}
case *gateway.GuildRoleCreateEvent:
if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil {
s.stateErr(err, "Failed to add a role in state")
if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil {
s.stateErr(err, "failed to add a role in state")
}
case *gateway.GuildRoleUpdateEvent:
if err := s.Store.RoleSet(ev.GuildID, &ev.Role); err != nil {
s.stateErr(err, "Failed to update a role in state")
if err := s.Store.RoleSet(ev.GuildID, ev.Role); err != nil {
s.stateErr(err, "failed to update a role in state")
}
case *gateway.GuildRoleDeleteEvent:
if err := s.Store.RoleRemove(ev.GuildID, ev.RoleID); err != nil {
s.stateErr(err, "Failed to remove a role in state")
s.stateErr(err, "failed to remove a role in state")
}
case *gateway.GuildEmojisUpdateEvent:
if err := s.Store.EmojiSet(ev.GuildID, ev.Emojis); err != nil {
s.stateErr(err, "Failed to update emojis in state")
s.stateErr(err, "failed to update emojis in state")
}
case *gateway.ChannelCreateEvent:
if err := s.Store.ChannelSet((*discord.Channel)(ev)); err != nil {
s.stateErr(err, "Failed to create a channel in state")
if err := s.Store.ChannelSet(ev.Channel); err != nil {
s.stateErr(err, "failed to create a channel in state")
}
case *gateway.ChannelUpdateEvent:
if err := s.Store.ChannelSet((*discord.Channel)(ev)); err != nil {
s.stateErr(err, "Failed to update a channel in state")
if err := s.Store.ChannelSet(ev.Channel); err != nil {
s.stateErr(err, "failed to update a channel in state")
}
case *gateway.ChannelDeleteEvent:
if err := s.Store.ChannelRemove((*discord.Channel)(ev)); err != nil {
s.stateErr(err, "Failed to remove a channel in state")
if err := s.Store.ChannelRemove(ev.Channel); err != nil {
s.stateErr(err, "failed to remove a channel in state")
}
case *gateway.ChannelPinsUpdateEvent:
// not tracked.
case *gateway.MessageCreateEvent:
if err := s.Store.MessageSet(&ev.Message); err != nil {
s.stateErr(err, "Failed to add a message in state")
if err := s.Store.MessageSet(ev.Message); err != nil {
s.stateErr(err, "failed to add a message in state")
}
case *gateway.MessageUpdateEvent:
if err := s.Store.MessageSet(&ev.Message); err != nil {
s.stateErr(err, "Failed to update a message in state")
if err := s.Store.MessageSet(ev.Message); err != nil {
s.stateErr(err, "failed to update a message in state")
}
case *gateway.MessageDeleteEvent:
if err := s.Store.MessageRemove(ev.ChannelID, ev.ID); err != nil {
s.stateErr(err, "Failed to delete a message in state")
s.stateErr(err, "failed to delete a message in state")
}
case *gateway.MessageDeleteBulkEvent:
for _, id := range ev.IDs {
if err := s.Store.MessageRemove(ev.ChannelID, id); err != nil {
s.stateErr(err, "Failed to delete bulk meessages in state")
s.stateErr(err, "failed to delete bulk messages in state")
}
}
@ -224,16 +251,14 @@ func (s *State) onEvent(iface interface{}) {
})
case *gateway.PresenceUpdateEvent:
if err := s.Store.PresenceSet(ev.GuildID, &ev.Presence); err != nil {
s.stateErr(err, "Failed to update presence in state")
if err := s.Store.PresenceSet(ev.GuildID, ev.Presence); err != nil {
s.stateErr(err, "failed to update presence in state")
}
case *gateway.PresencesReplaceEvent:
for i := range *ev {
p := (*ev)[i]
if err := s.Store.PresenceSet(p.GuildID, &p); err != nil {
s.stateErr(err, "Failed to update presence in state")
for _, p := range *ev {
if err := s.Store.PresenceSet(p.GuildID, p); err != nil {
s.stateErr(err, "failed to update presence in state")
}
}
@ -253,19 +278,19 @@ func (s *State) onEvent(iface interface{}) {
s.Ready.Notes[ev.ID] = ev.Note
case *gateway.UserUpdateEvent:
if err := s.Store.MyselfSet(&ev.User); err != nil {
s.stateErr(err, "Failed to update myself from USER_UPDATE")
if err := s.Store.MyselfSet(ev.User); err != nil {
s.stateErr(err, "failed to update myself from USER_UPDATE")
}
case *gateway.VoiceStateUpdateEvent:
vs := &ev.VoiceState
if vs.ChannelID == 0 {
if err := s.Store.VoiceStateRemove(vs.GuildID, vs.UserID); err != nil {
s.stateErr(err, "Failed to remove voice state from state")
s.stateErr(err, "failed to remove voice state from state")
}
} else {
if err := s.Store.VoiceStateSet(vs.GuildID, vs); err != nil {
s.stateErr(err, "Failed to update voice state in state")
if err := s.Store.VoiceStateSet(vs.GuildID, *vs); err != nil {
s.stateErr(err, "failed to update voice state in state")
}
}
}
@ -282,7 +307,7 @@ func (s *State) batchLog(errors ...error) {
// Helper functions
func (s *State) editMessage(ch, msg discord.Snowflake, fn func(m *discord.Message) bool) {
func (s *State) editMessage(ch discord.ChannelID, msg discord.MessageID, fn func(m *discord.Message) bool) {
m, err := s.Store.Message(ch, msg)
if err != nil {
return
@ -290,8 +315,8 @@ func (s *State) editMessage(ch, msg discord.Snowflake, fn func(m *discord.Messag
if !fn(m) {
return
}
if err := s.Store.MessageSet(m); err != nil {
s.stateErr(err, "Failed to save message in reaction add")
if err := s.Store.MessageSet(*m); err != nil {
s.stateErr(err, "failed to save message in reaction add")
}
}
@ -304,51 +329,52 @@ func findReaction(rs []discord.Reaction, emoji discord.Emoji) int {
return -1
}
func handleGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error {
// If a guild is unavailable, don't populate it in the state, as the guild
// data is very incomplete.
func storeGuildCreate(store Store, guild *gateway.GuildCreateEvent) []error {
if guild.Unavailable {
return nil
}
stack, error := newErrorStack()
stack, errs := newErrorStack()
if err := store.GuildSet(&guild.Guild); err != nil {
error(err, "Failed to set guild in Ready")
if err := store.GuildSet(guild.Guild); err != nil {
errs(err, "failed to set guild in Ready")
}
// Handle guild emojis
if guild.Emojis != nil {
if err := store.EmojiSet(guild.ID, guild.Emojis); err != nil {
error(err, "Failed to set guild emojis")
errs(err, "failed to set guild emojis")
}
}
// Handle guild member
for i := range guild.Members {
if err := store.MemberSet(guild.ID, &guild.Members[i]); err != nil {
error(err, "Failed to set guild member in Ready")
for _, m := range guild.Members {
if err := store.MemberSet(guild.ID, m); err != nil {
errs(err, "failed to set guild member in Ready")
}
}
// Handle guild channels
for i := range guild.Channels {
if err := store.ChannelSet(&guild.Channels[i]); err != nil {
error(err, "Failed to set guild channel in Ready")
for _, ch := range guild.Channels {
// I HATE Discord.
ch.GuildID = guild.ID
if err := store.ChannelSet(ch); err != nil {
errs(err, "failed to set guild channel in Ready")
}
}
// Handle guild presences
for i := range guild.Presences {
if err := store.PresenceSet(guild.ID, &guild.Presences[i]); err != nil {
error(err, "Failed to set guild presence in Ready")
for _, p := range guild.Presences {
if err := store.PresenceSet(guild.ID, p); err != nil {
errs(err, "failed to set guild presence in Ready")
}
}
// Handle guild voice states
for i := range guild.VoiceStates {
if err := store.VoiceStateSet(guild.ID, &guild.VoiceStates[i]); err != nil {
error(err, "Failed to set guild voice state in Ready")
for _, v := range guild.VoiceStates {
if err := store.VoiceStateSet(guild.ID, v); err != nil {
errs(err, "failed to set guild voice state in Ready")
}
}

View file

@ -21,69 +21,81 @@ type Store interface {
// would mutate the underlying slice (and as a result the returned slice as
// well). The best way to avoid this is to copy the whole slice, like
// DefaultStore does.
//
// These methods should not care about returning slices in order, unless
// explicitly stated against.
type StoreGetter interface {
Me() (*discord.User, error)
// Channel should check for both DM and guild channels.
Channel(id discord.Snowflake) (*discord.Channel, error)
Channels(guildID discord.Snowflake) ([]discord.Channel, error)
Channel(id discord.ChannelID) (*discord.Channel, error)
Channels(guildID discord.GuildID) ([]discord.Channel, error)
// same API as (*api.Client)
CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error)
CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error)
PrivateChannels() ([]discord.Channel, error)
Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error)
Emojis(guildID discord.Snowflake) ([]discord.Emoji, error)
Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error)
Emojis(guildID discord.GuildID) ([]discord.Emoji, error)
Guild(id discord.Snowflake) (*discord.Guild, error)
Guild(id discord.GuildID) (*discord.Guild, error)
Guilds() ([]discord.Guild, error)
Member(guildID, userID discord.Snowflake) (*discord.Member, error)
Members(guildID discord.Snowflake) ([]discord.Member, error)
Member(guildID discord.GuildID, userID discord.UserID) (*discord.Member, error)
Members(guildID discord.GuildID) ([]discord.Member, error)
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
Message(channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error)
// Messages should return messages ordered from latest to earliest.
Messages(channelID discord.ChannelID) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.
// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
Presences(guildID discord.Snowflake) ([]discord.Presence, error)
Presence(guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error)
Presences(guildID discord.GuildID) ([]discord.Presence, error)
Role(guildID, roleID discord.Snowflake) (*discord.Role, error)
Roles(guildID discord.Snowflake) ([]discord.Role, error)
Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error)
Roles(guildID discord.GuildID) ([]discord.Role, error)
VoiceState(guildID discord.Snowflake, userID discord.Snowflake) (*discord.VoiceState, error)
VoiceStates(guildID discord.Snowflake) ([]discord.VoiceState, error)
VoiceState(guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error)
VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error)
}
type StoreModifier interface {
MyselfSet(me *discord.User) error
MyselfSet(me discord.User) error
// ChannelSet should switch on Type to know if it's a private channel or
// not.
ChannelSet(*discord.Channel) error
ChannelRemove(*discord.Channel) error
ChannelSet(discord.Channel) error
ChannelRemove(discord.Channel) error
// EmojiSet should delete all old emojis before setting new ones.
EmojiSet(guildID discord.Snowflake, emojis []discord.Emoji) error
EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error
GuildSet(*discord.Guild) error
GuildRemove(id discord.Snowflake) error
GuildSet(discord.Guild) error
GuildRemove(id discord.GuildID) error
MemberSet(guildID discord.Snowflake, member *discord.Member) error
MemberRemove(guildID, userID discord.Snowflake) error
MemberSet(guildID discord.GuildID, member discord.Member) error
MemberRemove(guildID discord.GuildID, userID discord.UserID) error
MessageSet(*discord.Message) error
MessageRemove(channelID, messageID discord.Snowflake) error
// MessageSet should prepend messages into the slice, the latest being in
// front.
MessageSet(discord.Message) error
MessageRemove(channelID discord.ChannelID, messageID discord.MessageID) error
PresenceSet(guildID discord.Snowflake, presence *discord.Presence) error
PresenceRemove(guildID, userID discord.Snowflake) error
PresenceSet(guildID discord.GuildID, presence discord.Presence) error
PresenceRemove(guildID discord.GuildID, userID discord.UserID) error
RoleSet(guildID discord.Snowflake, role *discord.Role) error
RoleRemove(guildID, roleID discord.Snowflake) error
RoleSet(guildID discord.GuildID, role discord.Role) error
RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error
VoiceStateSet(guildID discord.Snowflake, voiceState *discord.VoiceState) error
VoiceStateRemove(guildID discord.Snowflake, userID discord.Snowflake) error
VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error
VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error
}
// StoreResetter is used by the state to reset the store on every Ready event.
type StoreResetter interface {
// Reset resets the store to a new valid instance.
Reset() error
}
// ErrStoreNotFound is an error that a store can use to return when something
@ -97,7 +109,7 @@ func DiffMessage(src discord.Message, dst *discord.Message) {
if src.Content != "" {
dst.Content = src.Content
}
if src.EditedTimestamp.Valid() {
if src.EditedTimestamp.IsValid() {
dst.EditedTimestamp = src.EditedTimestamp
}
if src.Mentions != nil {
@ -109,10 +121,10 @@ func DiffMessage(src discord.Message, dst *discord.Message) {
if src.Attachments != nil {
dst.Attachments = src.Attachments
}
if src.Timestamp.Valid() {
if src.Timestamp.IsValid() {
dst.Timestamp = src.Timestamp
}
if src.Author.ID.Valid() {
if src.Author.ID.IsValid() {
dst.Author = src.Author
}
if src.Reactions != nil {

View file

@ -1,7 +1,6 @@
package state
import (
"sort"
"sync"
"github.com/diamondburned/arikawa/discord"
@ -10,21 +9,25 @@ import (
// TODO: make an ExpiryStore
type DefaultStore struct {
*DefaultStoreOptions
DefaultStoreOptions
self discord.User
// includes normal and private
privates map[discord.Snowflake]*discord.Channel // channelID:channel
guilds map[discord.Snowflake]*discord.Guild // guildID:guild
privates map[discord.ChannelID]discord.Channel
guilds map[discord.GuildID]discord.Guild
channels map[discord.Snowflake][]discord.Channel // guildID:channels
members map[discord.Snowflake][]discord.Member // guildID:members
presences map[discord.Snowflake][]discord.Presence // guildID:presences
messages map[discord.Snowflake][]discord.Message // channelID:messages
voiceStates map[discord.Snowflake][]discord.VoiceState // guildID:voiceStates
roles map[discord.GuildID][]discord.Role
emojis map[discord.GuildID][]discord.Emoji
channels map[discord.GuildID][]discord.Channel
presences map[discord.GuildID][]discord.Presence
voiceStates map[discord.GuildID][]discord.VoiceState
messages map[discord.ChannelID][]discord.Message
mut sync.Mutex
// special case; optimize for lots of members
members map[discord.GuildID]map[discord.UserID]discord.Member
mut sync.RWMutex
}
type DefaultStoreOptions struct {
@ -40,9 +43,7 @@ func NewDefaultStore(opts *DefaultStoreOptions) *DefaultStore {
}
}
ds := &DefaultStore{
DefaultStoreOptions: opts,
}
ds := &DefaultStore{DefaultStoreOptions: *opts}
ds.Reset()
return ds
@ -54,14 +55,17 @@ func (s *DefaultStore) Reset() error {
s.self = discord.User{}
s.privates = map[discord.Snowflake]*discord.Channel{}
s.guilds = map[discord.Snowflake]*discord.Guild{}
s.privates = map[discord.ChannelID]discord.Channel{}
s.guilds = map[discord.GuildID]discord.Guild{}
s.channels = map[discord.Snowflake][]discord.Channel{}
s.members = map[discord.Snowflake][]discord.Member{}
s.presences = map[discord.Snowflake][]discord.Presence{}
s.messages = map[discord.Snowflake][]discord.Message{}
s.voiceStates = map[discord.Snowflake][]discord.VoiceState{}
s.roles = map[discord.GuildID][]discord.Role{}
s.emojis = map[discord.GuildID][]discord.Emoji{}
s.channels = map[discord.GuildID][]discord.Channel{}
s.presences = map[discord.GuildID][]discord.Presence{}
s.voiceStates = map[discord.GuildID][]discord.VoiceState{}
s.messages = map[discord.ChannelID][]discord.Message{}
s.members = map[discord.GuildID]map[discord.UserID]discord.Member{}
return nil
}
@ -69,19 +73,19 @@ func (s *DefaultStore) Reset() error {
////
func (s *DefaultStore) Me() (*discord.User, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
if !s.self.ID.Valid() {
if !s.self.ID.IsValid() {
return nil, ErrStoreNotFound
}
return &s.self, nil
}
func (s *DefaultStore) MyselfSet(me *discord.User) error {
func (s *DefaultStore) MyselfSet(me discord.User) error {
s.mut.Lock()
s.self = *me
s.self = me
s.mut.Unlock()
return nil
@ -89,12 +93,13 @@ func (s *DefaultStore) MyselfSet(me *discord.User) error {
////
func (s *DefaultStore) Channel(id discord.Snowflake) (*discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
if ch, ok := s.privates[id]; ok {
return ch, nil
// implicit copy
return &ch, nil
}
for _, chs := range s.channels {
@ -108,9 +113,9 @@ func (s *DefaultStore) Channel(id discord.Snowflake) (*discord.Channel, error) {
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Channels(guildID discord.Snowflake) ([]discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Channels(guildID discord.GuildID) ([]discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
chs, ok := s.channels[guildID]
if !ok {
@ -122,17 +127,18 @@ func (s *DefaultStore) Channels(guildID discord.Snowflake) ([]discord.Channel, e
// CreatePrivateChannel searches in the cache for a private channel. It makes no
// API calls.
func (s *DefaultStore) CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) CreatePrivateChannel(recipient discord.UserID) (*discord.Channel, error) {
s.mut.RLock()
defer s.mut.RUnlock()
// slow way
for _, ch := range s.privates {
if ch.Type != discord.DirectMessage || len(ch.DMRecipients) < 1 {
if ch.Type != discord.DirectMessage || len(ch.DMRecipients) == 0 {
continue
}
if ch.DMRecipients[0].ID == recipient {
return &(*ch), nil
// Return an implicit copy made by range.
return &ch, nil
}
}
return nil, ErrStoreNotFound
@ -140,22 +146,22 @@ func (s *DefaultStore) CreatePrivateChannel(recipient discord.Snowflake) (*disco
// PrivateChannels returns a list of Direct Message channels randomly ordered.
func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) {
s.mut.Lock()
defer s.mut.Unlock()
s.mut.RLock()
defer s.mut.RUnlock()
var chs = make([]discord.Channel, 0, len(s.privates))
for _, ch := range s.privates {
chs = append(chs, *ch)
for i := range s.privates {
chs = append(chs, s.privates[i])
}
return chs, nil
}
func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
func (s *DefaultStore) ChannelSet(channel discord.Channel) error {
s.mut.Lock()
defer s.mut.Unlock()
if !channel.GuildID.Valid() {
if !channel.GuildID.IsValid() {
s.privates[channel.ID] = channel
} else {
@ -169,20 +175,20 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
}
// Found, just edit
chs[i] = *channel
chs[i] = channel
return nil
}
}
chs = append(chs, *channel)
chs = append(chs, channel)
s.channels[channel.GuildID] = chs
}
return nil
}
func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
func (s *DefaultStore) ChannelRemove(channel discord.Channel) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -193,9 +199,11 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
for i, ch := range chs {
if ch.ID == channel.ID {
chs = append(chs[:i], chs[i+1:]...)
s.channels[channel.GuildID] = chs
// Fast unordered delete.
chs[i] = chs[len(chs)-1]
chs = chs[:len(chs)-1]
s.channels[channel.GuildID] = chs
return nil
}
}
@ -205,17 +213,18 @@ func (s *DefaultStore) ChannelRemove(channel *discord.Channel) error {
////
func (s *DefaultStore) Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Emoji(guildID discord.GuildID, emojiID discord.EmojiID) (*discord.Emoji, error) {
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
emojis, ok := s.emojis[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, emoji := range gd.Emojis {
for _, emoji := range emojis {
if emoji.ID == emojiID {
// Emoji is an implicit copy, so we could do this safely.
return &emoji, nil
}
}
@ -223,169 +232,133 @@ func (s *DefaultStore) Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Emojis(guildID discord.Snowflake) ([]discord.Emoji, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Emojis(guildID discord.GuildID) ([]discord.Emoji, error) {
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
emojis, ok := s.emojis[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Emoji{}, gd.Emojis...), nil
return append([]discord.Emoji{}, emojis...), nil
}
func (s *DefaultStore) EmojiSet(guildID discord.Snowflake, emojis []discord.Emoji) error {
func (s *DefaultStore) EmojiSet(guildID discord.GuildID, emojis []discord.Emoji) error {
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
if !ok {
return ErrStoreNotFound
}
// A nil slice is acceptable, as we'll make a new slice later on and set it.
s.emojis[guildID] = emojis
filtered := emojis[:0]
Main:
for _, enew := range emojis {
// Try and see if this emoji is already in the slice
for i, emoji := range gd.Emojis {
if emoji.ID == enew.ID {
// If it is, we simply replace it
gd.Emojis[i] = enew
continue Main
}
}
// If not, we add it to the slice that's to be appended.
filtered = append(filtered, enew)
}
// Append the new emojis
gd.Emojis = append(gd.Emojis, filtered...)
return nil
}
////
func (s *DefaultStore) Guild(id discord.Snowflake) (*discord.Guild, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Guild(id discord.GuildID) (*discord.Guild, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ch, ok := s.guilds[id]
if !ok {
return nil, ErrStoreNotFound
}
return ch, nil
// implicit copy
return &ch, nil
}
func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
s.mut.Lock()
s.mut.RLock()
defer s.mut.RUnlock()
if len(s.guilds) == 0 {
s.mut.Unlock()
return nil, ErrStoreNotFound
}
var gs = make([]discord.Guild, 0, len(s.guilds))
for _, g := range s.guilds {
gs = append(gs, *g)
gs = append(gs, g)
}
s.mut.Unlock()
sort.Slice(gs, func(i, j int) bool {
return gs[i].ID > gs[j].ID
})
return gs, nil
}
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
func (s *DefaultStore) GuildSet(guild discord.Guild) error {
s.mut.Lock()
defer s.mut.Unlock()
if g, ok := s.guilds[guild.ID]; ok {
// preserve state stuff
if guild.Roles == nil {
guild.Roles = g.Roles
}
if guild.Emojis == nil {
guild.Emojis = g.Emojis
}
}
s.guilds[guild.ID] = guild
return nil
}
func (s *DefaultStore) GuildRemove(id discord.Snowflake) error {
func (s *DefaultStore) GuildRemove(id discord.GuildID) error {
s.mut.Lock()
delete(s.guilds, id)
s.mut.Unlock()
defer s.mut.Unlock()
if _, ok := s.guilds[id]; !ok {
return ErrStoreNotFound
}
delete(s.guilds, id)
return nil
}
////
func (s *DefaultStore) Member(guildID, userID discord.Snowflake) (*discord.Member, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Member(
guildID discord.GuildID, userID discord.UserID) (*discord.Member, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.members[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, m := range ms {
if m.User.ID == userID {
return &m, nil
}
m, ok := ms[userID]
if ok {
return &m, nil
}
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Members(guildID discord.Snowflake) ([]discord.Member, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Members(guildID discord.GuildID) ([]discord.Member, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.members[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Member{}, ms...), nil
var members = make([]discord.Member, 0, len(ms))
for _, m := range ms {
members = append(members, m)
}
return members, nil
}
func (s *DefaultStore) MemberSet(guildID discord.Snowflake, member *discord.Member) error {
func (s *DefaultStore) MemberSet(guildID discord.GuildID, member discord.Member) error {
s.mut.Lock()
defer s.mut.Unlock()
ms := s.members[guildID]
// Try and see if this member is already in the slice
for i, m := range ms {
if m.User.ID == member.User.ID {
// If it is, we simply replace it
ms[i] = *member
s.members[guildID] = ms
return nil
}
ms, ok := s.members[guildID]
if !ok {
ms = make(map[discord.UserID]discord.Member, 1)
}
// Append the new member
ms = append(ms, *member)
ms[member.User.ID] = member
s.members[guildID] = ms
return nil
}
func (s *DefaultStore) MemberRemove(guildID, userID discord.Snowflake) error {
func (s *DefaultStore) MemberRemove(guildID discord.GuildID, userID discord.UserID) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -394,24 +367,21 @@ func (s *DefaultStore) MemberRemove(guildID, userID discord.Snowflake) error {
return ErrStoreNotFound
}
// Try and see if this member is already in the slice
for i, m := range ms {
if m.User.ID == userID {
ms = append(ms, ms[i+1:]...)
s.members[guildID] = ms
return nil
}
if _, ok := ms[userID]; !ok {
return ErrStoreNotFound
}
return ErrStoreNotFound
delete(ms, userID)
return nil
}
////
func (s *DefaultStore) Message(channelID, messageID discord.Snowflake) (*discord.Message, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Message(
channelID discord.ChannelID, messageID discord.MessageID) (*discord.Message, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.messages[channelID]
if !ok {
@ -427,25 +397,23 @@ func (s *DefaultStore) Message(channelID, messageID discord.Snowflake) (*discord
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Messages(channelID discord.ChannelID) ([]discord.Message, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ms, ok := s.messages[channelID]
if !ok {
return nil, ErrStoreNotFound
}
cp := make([]discord.Message, len(ms))
copy(cp, ms)
return cp, nil
return append([]discord.Message{}, ms...), nil
}
func (s *DefaultStore) MaxMessages() int {
return int(s.DefaultStoreOptions.MaxMessages)
}
func (s *DefaultStore) MessageSet(message *discord.Message) error {
func (s *DefaultStore) MessageSet(message discord.Message) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -457,30 +425,38 @@ func (s *DefaultStore) MessageSet(message *discord.Message) error {
// Check if we already have the message.
for i, m := range ms {
if m.ID == message.ID {
DiffMessage(*message, &m)
DiffMessage(message, &m)
ms[i] = m
return nil
}
}
// Prepend the latest message at the end
if end := s.MaxMessages(); len(ms) >= end {
// Copy hack to prepend. This copies the 0th-(end-1)th entries to
// 1st-endth.
copy(ms[1:end], ms[0:end-1])
// Then, set the 0th entry.
ms[0] = *message
// Order: latest to earliest, similar to the API.
var end = len(ms)
if max := s.MaxMessages(); end >= max {
// If the end (length) is larger than the maximum amount, then cap it.
end = max
} else {
ms = append(ms, *message)
// Else, append an empty message to the end.
ms = append(ms, discord.Message{})
// Increment to update the length.
end++
}
// Copy hack to prepend. This copies the 0th-(end-1)th entries to
// 1st-endth.
copy(ms[1:end], ms[0:end-1])
// Then, set the 0th entry.
ms[0] = message
s.messages[message.ChannelID] = ms
return nil
}
func (s *DefaultStore) MessageRemove(channelID, messageID discord.Snowflake) error {
func (s *DefaultStore) MessageRemove(
channelID discord.ChannelID, messageID discord.MessageID) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -502,9 +478,11 @@ func (s *DefaultStore) MessageRemove(channelID, messageID discord.Snowflake) err
////
func (s *DefaultStore) Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Presence(
guildID discord.GuildID, userID discord.UserID) (*discord.Presence, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ps, ok := s.presences[guildID]
if !ok {
@ -520,39 +498,38 @@ func (s *DefaultStore) Presence(guildID, userID discord.Snowflake) (*discord.Pre
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Presences(guildID discord.Snowflake) ([]discord.Presence, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Presences(guildID discord.GuildID) ([]discord.Presence, error) {
s.mut.RLock()
defer s.mut.RUnlock()
ps, ok := s.presences[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return ps, nil
return append([]discord.Presence{}, ps...), nil
}
func (s *DefaultStore) PresenceSet(guildID discord.Snowflake, presence *discord.Presence) error {
func (s *DefaultStore) PresenceSet(guildID discord.GuildID, presence discord.Presence) error {
s.mut.Lock()
defer s.mut.Unlock()
ps := s.presences[guildID]
ps, _ := s.presences[guildID]
for i, p := range ps {
if p.User.ID == presence.User.ID {
ps[i] = *presence
s.presences[guildID] = ps
// Change the backing array.
ps[i] = presence
return nil
}
}
ps = append(ps, *presence)
ps = append(ps, presence)
s.presences[guildID] = ps
return nil
}
func (s *DefaultStore) PresenceRemove(guildID, userID discord.Snowflake) error {
func (s *DefaultStore) PresenceRemove(guildID discord.GuildID, userID discord.UserID) error {
s.mut.Lock()
defer s.mut.Unlock()
@ -563,9 +540,10 @@ func (s *DefaultStore) PresenceRemove(guildID, userID discord.Snowflake) error {
for i, p := range ps {
if p.User.ID == userID {
ps = append(ps[:i], ps[i+1:]...)
s.presences[guildID] = ps
ps[i] = ps[len(ps)-1]
ps = ps[:len(ps)-1]
s.presences[guildID] = ps
return nil
}
}
@ -575,16 +553,16 @@ func (s *DefaultStore) PresenceRemove(guildID, userID discord.Snowflake) error {
////
func (s *DefaultStore) Role(guildID, roleID discord.Snowflake) (*discord.Role, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Role(guildID discord.GuildID, roleID discord.RoleID) (*discord.Role, error) {
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return nil, ErrStoreNotFound
}
for _, r := range gd.Roles {
for _, r := range rs {
if r.ID == roleID {
return &r, nil
}
@ -593,50 +571,55 @@ func (s *DefaultStore) Role(guildID, roleID discord.Snowflake) (*discord.Role, e
return nil, ErrStoreNotFound
}
func (s *DefaultStore) Roles(guildID discord.Snowflake) ([]discord.Role, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) Roles(guildID discord.GuildID) ([]discord.Role, error) {
s.mut.RLock()
defer s.mut.RUnlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return nil, ErrStoreNotFound
}
return append([]discord.Role{}, gd.Roles...), nil
return append([]discord.Role{}, rs...), nil
}
func (s *DefaultStore) RoleSet(guildID discord.Snowflake, role *discord.Role) error {
func (s *DefaultStore) RoleSet(guildID discord.GuildID, role discord.Role) error {
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
if !ok {
return ErrStoreNotFound
}
// A nil slice is fine, since we can just append the role.
rs, _ := s.roles[guildID]
for i, r := range gd.Roles {
for i, r := range rs {
if r.ID == role.ID {
gd.Roles[i] = *role
// This changes the backing array, so we don't need to reset the
// slice.
rs[i] = role
return nil
}
}
gd.Roles = append(gd.Roles, *role)
rs = append(rs, role)
s.roles[guildID] = rs
return nil
}
func (s *DefaultStore) RoleRemove(guildID, roleID discord.Snowflake) error {
func (s *DefaultStore) RoleRemove(guildID discord.GuildID, roleID discord.RoleID) error {
s.mut.Lock()
defer s.mut.Unlock()
gd, ok := s.guilds[guildID]
rs, ok := s.roles[guildID]
if !ok {
return ErrStoreNotFound
}
for i, r := range gd.Roles {
for i, r := range rs {
if r.ID == roleID {
gd.Roles = append(gd.Roles[:i], gd.Roles[i+1:]...)
// Fast delete.
rs[i] = rs[len(rs)-1]
rs = rs[:len(rs)-1]
s.roles[guildID] = rs
return nil
}
}
@ -646,9 +629,11 @@ func (s *DefaultStore) RoleRemove(guildID, roleID discord.Snowflake) error {
////
func (s *DefaultStore) VoiceState(guildID, userID discord.Snowflake) (*discord.VoiceState, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) VoiceState(
guildID discord.GuildID, userID discord.UserID) (*discord.VoiceState, error) {
s.mut.RLock()
defer s.mut.RUnlock()
states, ok := s.voiceStates[guildID]
if !ok {
@ -664,9 +649,9 @@ func (s *DefaultStore) VoiceState(guildID, userID discord.Snowflake) (*discord.V
return nil, ErrStoreNotFound
}
func (s *DefaultStore) VoiceStates(guildID discord.Snowflake) ([]discord.VoiceState, error) {
s.mut.Lock()
defer s.mut.Unlock()
func (s *DefaultStore) VoiceStates(guildID discord.GuildID) ([]discord.VoiceState, error) {
s.mut.RLock()
defer s.mut.RUnlock()
states, ok := s.voiceStates[guildID]
if !ok {
@ -676,27 +661,26 @@ func (s *DefaultStore) VoiceStates(guildID discord.Snowflake) ([]discord.VoiceSt
return append([]discord.VoiceState{}, states...), nil
}
func (s *DefaultStore) VoiceStateSet(guildID discord.Snowflake, voiceState *discord.VoiceState) error {
func (s *DefaultStore) VoiceStateSet(guildID discord.GuildID, voiceState discord.VoiceState) error {
s.mut.Lock()
defer s.mut.Unlock()
states := s.voiceStates[guildID]
states, _ := s.voiceStates[guildID]
for i, vs := range states {
if vs.UserID == voiceState.UserID {
states[i] = *voiceState
s.voiceStates[guildID] = states
// change the backing array
states[i] = voiceState
return nil
}
}
states = append(states, *voiceState)
states = append(states, voiceState)
s.voiceStates[guildID] = states
return nil
}
func (s *DefaultStore) VoiceStateRemove(guildID, userID discord.Snowflake) error {
func (s *DefaultStore) VoiceStateRemove(guildID discord.GuildID, userID discord.UserID) error {
s.mut.Lock()
defer s.mut.Unlock()

View file

@ -13,7 +13,7 @@ type NoopStore struct{}
var _ Store = (*NoopStore)(nil)
var ErrNotImplemented = errors.New("State is not implemented")
var ErrNotImplemented = errors.New("state is not implemented")
func (NoopStore) Reset() error {
return nil
@ -23,19 +23,19 @@ func (NoopStore) Me() (*discord.User, error) {
return nil, ErrNotImplemented
}
func (NoopStore) MyselfSet(*discord.User) error {
func (NoopStore) MyselfSet(discord.User) error {
return nil
}
func (NoopStore) Channel(discord.Snowflake) (*discord.Channel, error) {
func (NoopStore) Channel(discord.ChannelID) (*discord.Channel, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Channels(discord.Snowflake) ([]discord.Channel, error) {
func (NoopStore) Channels(discord.GuildID) ([]discord.Channel, error) {
return nil, ErrNotImplemented
}
func (NoopStore) CreatePrivateChannel(discord.Snowflake) (*discord.Channel, error) {
func (NoopStore) CreatePrivateChannel(discord.UserID) (*discord.Channel, error) {
return nil, ErrNotImplemented
}
@ -43,27 +43,27 @@ func (NoopStore) PrivateChannels() ([]discord.Channel, error) {
return nil, ErrNotImplemented
}
func (NoopStore) ChannelSet(*discord.Channel) error {
func (NoopStore) ChannelSet(discord.Channel) error {
return nil
}
func (NoopStore) ChannelRemove(*discord.Channel) error {
func (NoopStore) ChannelRemove(discord.Channel) error {
return nil
}
func (NoopStore) Emoji(_, _ discord.Snowflake) (*discord.Emoji, error) {
func (NoopStore) Emoji(discord.GuildID, discord.EmojiID) (*discord.Emoji, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Emojis(discord.Snowflake) ([]discord.Emoji, error) {
func (NoopStore) Emojis(discord.GuildID) ([]discord.Emoji, error) {
return nil, ErrNotImplemented
}
func (NoopStore) EmojiSet(discord.Snowflake, []discord.Emoji) error {
func (NoopStore) EmojiSet(discord.GuildID, []discord.Emoji) error {
return nil
}
func (NoopStore) Guild(discord.Snowflake) (*discord.Guild, error) {
func (NoopStore) Guild(discord.GuildID) (*discord.Guild, error) {
return nil, ErrNotImplemented
}
@ -71,35 +71,35 @@ func (NoopStore) Guilds() ([]discord.Guild, error) {
return nil, ErrNotImplemented
}
func (NoopStore) GuildSet(*discord.Guild) error {
func (NoopStore) GuildSet(discord.Guild) error {
return nil
}
func (NoopStore) GuildRemove(discord.Snowflake) error {
func (NoopStore) GuildRemove(discord.GuildID) error {
return nil
}
func (NoopStore) Member(_, _ discord.Snowflake) (*discord.Member, error) {
func (NoopStore) Member(discord.GuildID, discord.UserID) (*discord.Member, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Members(discord.Snowflake) ([]discord.Member, error) {
func (NoopStore) Members(discord.GuildID) ([]discord.Member, error) {
return nil, ErrNotImplemented
}
func (NoopStore) MemberSet(discord.Snowflake, *discord.Member) error {
func (NoopStore) MemberSet(discord.GuildID, discord.Member) error {
return nil
}
func (NoopStore) MemberRemove(_, _ discord.Snowflake) error {
func (NoopStore) MemberRemove(discord.GuildID, discord.UserID) error {
return nil
}
func (NoopStore) Message(_, _ discord.Snowflake) (*discord.Message, error) {
func (NoopStore) Message(discord.ChannelID, discord.MessageID) (*discord.Message, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Messages(discord.Snowflake) ([]discord.Message, error) {
func (NoopStore) Messages(discord.ChannelID) ([]discord.Message, error) {
return nil, ErrNotImplemented
}
@ -109,58 +109,58 @@ func (NoopStore) MaxMessages() int {
return 100
}
func (NoopStore) MessageSet(*discord.Message) error {
func (NoopStore) MessageSet(discord.Message) error {
return nil
}
func (NoopStore) MessageRemove(_, _ discord.Snowflake) error {
func (NoopStore) MessageRemove(discord.ChannelID, discord.MessageID) error {
return nil
}
func (NoopStore) Presence(_, _ discord.Snowflake) (*discord.Presence, error) {
func (NoopStore) Presence(discord.GuildID, discord.UserID) (*discord.Presence, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Presences(discord.Snowflake) ([]discord.Presence, error) {
func (NoopStore) Presences(discord.GuildID) ([]discord.Presence, error) {
return nil, ErrNotImplemented
}
func (NoopStore) PresenceSet(discord.Snowflake, *discord.Presence) error {
func (NoopStore) PresenceSet(discord.GuildID, discord.Presence) error {
return nil
}
func (NoopStore) PresenceRemove(_, _ discord.Snowflake) error {
func (NoopStore) PresenceRemove(discord.GuildID, discord.UserID) error {
return nil
}
func (NoopStore) Role(_, _ discord.Snowflake) (*discord.Role, error) {
func (NoopStore) Role(discord.GuildID, discord.RoleID) (*discord.Role, error) {
return nil, ErrNotImplemented
}
func (NoopStore) Roles(discord.Snowflake) ([]discord.Role, error) {
func (NoopStore) Roles(discord.GuildID) ([]discord.Role, error) {
return nil, ErrNotImplemented
}
func (NoopStore) RoleSet(discord.Snowflake, *discord.Role) error {
func (NoopStore) RoleSet(discord.GuildID, discord.Role) error {
return nil
}
func (NoopStore) RoleRemove(_, _ discord.Snowflake) error {
func (NoopStore) RoleRemove(discord.GuildID, discord.RoleID) error {
return nil
}
func (NoopStore) VoiceState(_, _ discord.Snowflake) (*discord.VoiceState, error) {
func (NoopStore) VoiceState(discord.GuildID, discord.UserID) (*discord.VoiceState, error) {
return nil, ErrNotImplemented
}
func (NoopStore) VoiceStates(_ discord.Snowflake) ([]discord.VoiceState, error) {
func (NoopStore) VoiceStates(discord.GuildID) ([]discord.VoiceState, error) {
return nil, ErrNotImplemented
}
func (NoopStore) VoiceStateSet(discord.Snowflake, *discord.VoiceState) error {
func (NoopStore) VoiceStateSet(discord.GuildID, discord.VoiceState) error {
return ErrNotImplemented
}
func (NoopStore) VoiceStateRemove(_, _ discord.Snowflake) error {
func (NoopStore) VoiceStateRemove(discord.GuildID, discord.UserID) error {
return ErrNotImplemented
}

336
utils/handler/handler.go Normal file
View file

@ -0,0 +1,336 @@
// Package handler handles incoming Gateway events. It reflects the function's
// first argument and caches that for use in each event.
//
// Performance
//
// Each call to the event would take 167 ns/op for roughly each handler. Scaling
// that up to 100 handlers is roughly the same as multiplying 167 ns by 100,
// which gives 16700 ns or 0.0167 ms.
//
// BenchmarkReflect-8 7260909 167 ns/op
//
// Usage
//
// Handler's usage is mostly similar to Discordgo, in that AddHandler expects a
// function with only one argument or an event channel. For more information,
// refer to AddHandler.
package handler
import (
"context"
"fmt"
"reflect"
"sync"
"github.com/pkg/errors"
)
type Handler struct {
// Synchronous controls whether to spawn each event handler in its own
// goroutine. Default false (meaning goroutines are spawned).
Synchronous bool
handlers map[uint64]handler
horders []uint64
hserial uint64
hmutex sync.RWMutex
}
func New() *Handler {
return &Handler{
handlers: map[uint64]handler{},
}
}
// Call calls all handlers with the given event. This is an internal method; use
// with care.
func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
if !ok {
// This shouldn't ever happen, but we're adding this just in case.
continue
}
if handler.not(evT) {
continue
}
if h.Synchronous {
handler.call(evV)
} else {
go handler.call(evV)
}
}
}
// CallDirect is the same as Call, but only calls those event handlers that
// listen for this specific event, i.e. that aren't interface handlers.
func (h *Handler) CallDirect(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
if !ok {
// This shouldn't ever happen, but we're adding this just in case.
continue
}
if evT != handler.event {
continue
}
if h.Synchronous {
handler.call(evV)
} else {
go handler.call(evV)
}
}
}
// WaitFor blocks until there's an event. It's advised to use ChanFor instead,
// as WaitFor may skip some events if it's not ran fast enough after the event
// arrived.
func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interface{} {
var result = make(chan interface{})
cancel := h.AddHandler(func(v interface{}) {
if fn(v) {
result <- v
}
})
defer cancel()
select {
case r := <-result:
return r
case <-ctx.Done():
return nil
}
}
// ChanFor returns a channel that would receive all incoming events that match
// the callback given. The cancel() function removes the handler and drops all
// hanging goroutines.
//
// This method is more intended to be used as a filter. For a persistent event
// channel, consider adding it directly as a handler with AddHandler.
func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, cancel func()) {
result := make(chan interface{})
closer := make(chan struct{})
removeHandler := h.AddHandler(func(v interface{}) {
if fn(v) {
select {
case result <- v:
case <-closer:
}
}
})
// Only allow cancel to be called once.
var once sync.Once
cancel = func() {
once.Do(func() {
removeHandler()
close(closer)
})
}
out = result
return
}
// AddHandler adds the handler, returning a function that would remove this
// handler when called. A handler type is either a single-argument no-return
// function or a channel.
//
// Function
//
// A handler can be a function with a single argument that is the expected event
// type. It must not have any returns or any other number of arguments.
//
// // An example of a valid function handler.
// h.AddHandler(func(*gateway.MessageCreateEvent) {})
//
// Channel
//
// A handler can also be a channel. The underlying type that the channel wraps
// around will be the event type. As such, the type rules are the same as
// function handlers.
//
// Keep in mind that the user must NOT close the channel. In fact, the channel
// should not be closed at all. The caller function WILL PANIC if the channel is
// closed!
//
// When the rm callback that is returned is called, it will also guarantee that
// all blocking sends will be cancelled. This helps prevent dangling goroutines.
//
// // An example of a valid channel handler.
// ch := make(chan *gateway.MessageCreateEvent)
// h.AddHandler(ch)
//
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler)
if err != nil {
panic(err)
}
return rm
}
// AddHandlerCheck adds the handler, but safe-guards reflect panics with a
// recoverer, returning the error. Refer to AddHandler for more information.
func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
// Reflect would actually panic if anything goes wrong, so this is just in
// case.
defer func() {
if rec := recover(); rec != nil {
if recErr, ok := rec.(error); ok {
err = recErr
} else {
err = fmt.Errorf("%v", rec)
}
}
}()
return h.addHandler(handler)
}
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
// Reflect the handler
r, err := newHandler(fn)
if err != nil {
return nil, errors.Wrap(err, "handler reflect failed")
}
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Get the current counter value and increment the counter:
serial := h.hserial
h.hserial++
// Create a map if there's none:
if h.handlers == nil {
h.handlers = map[uint64]handler{}
}
// Use the serial for the map:
h.handlers[serial] = *r
// Append the serial into the list of keys:
h.horders = append(h.horders, serial)
return func() {
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Take and delete the handler from the map, but return if we can't find
// it.
hd, ok := h.handlers[serial]
if !ok {
return
}
delete(h.handlers, serial)
// Delete the key from the orders slice:
for i, order := range h.horders {
if order == serial {
h.horders = append(h.horders[:i], h.horders[i+1:]...)
break
}
}
// Clean up the handler.
hd.cleanup()
}, nil
}
type handler struct {
event reflect.Type // underlying type; arg0 or chan underlying type
callback reflect.Value
isIface bool
chanclose reflect.Value // IsValid() if chan
}
// newHandler reflects either a channel or a function into a handler. A function
// must only have a single argument being the event and no return, and a channel
// must have the event type as the underlying type.
func newHandler(unknown interface{}) (*handler, error) {
fnV := reflect.ValueOf(unknown)
fnT := fnV.Type()
// underlying event type
var handler = handler{
callback: fnV,
}
switch fnT.Kind() {
case reflect.Func:
if fnT.NumIn() != 1 {
return nil, errors.New("function can only accept 1 event as argument")
}
if fnT.NumOut() > 0 {
return nil, errors.New("function can't accept returns")
}
handler.event = fnT.In(0)
case reflect.Chan:
handler.event = fnT.Elem()
handler.chanclose = reflect.ValueOf(make(chan struct{}))
default:
return nil, errors.New("given interface is not a function or channel")
}
var kind = handler.event.Kind()
// Accept either pointer type or interface{} type
if kind != reflect.Ptr && kind != reflect.Interface {
return nil, errors.New("first argument is not pointer")
}
handler.isIface = kind == reflect.Interface
return &handler, nil
}
func (h handler) not(event reflect.Type) bool {
if h.isIface {
return !event.Implements(h.event)
}
return h.event != event
}
func (h *handler) call(event reflect.Value) {
if h.chanclose.IsValid() {
reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: h.callback, Send: event},
{Dir: reflect.SelectRecv, Chan: h.chanclose},
})
} else {
h.callback.Call([]reflect.Value{event})
}
}
func (h *handler) cleanup() {
if h.chanclose.IsValid() {
// Closing this channel will force all ongoing selects to return
// immediately.
h.chanclose.Close()
}
}

View file

@ -35,7 +35,7 @@ func TestCall(t *testing.T) {
t.Fatal("Returned results is wrong:", r)
}
// Remove handler test
// Delete handler test
rm()
go h.Call(newMessage("astolfo"))
@ -63,7 +63,7 @@ func TestCall(t *testing.T) {
func TestHandler(t *testing.T) {
var results = make(chan string)
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {
h, err := newHandler(func(m *gateway.MessageCreateEvent) {
results <- m.Content
})
if err != nil {
@ -87,10 +87,81 @@ func TestHandler(t *testing.T) {
}
}
func TestHandlerChan(t *testing.T) {
var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results)
if err != nil {
t.Fatal(err)
}
const result = "Hime Arikawa"
var msg = newMessage(result)
var msgV = reflect.ValueOf(msg)
var msgT = msgV.Type()
if h.not(msgT) {
t.Fatal("Event type mismatch")
}
go h.call(msgV)
if results := <-results; results.Content != result {
t.Fatal("Unexpected results:", results)
}
}
func TestHandlerChanCancel(t *testing.T) {
// Never receive from this channel. It is important that this channel is
// unbuffered.
var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results)
if err != nil {
t.Fatal(err)
}
const result = "Hime Arikawa"
var msg = newMessage(result)
var msgV = reflect.ValueOf(msg)
var msgT = msgV.Type()
if h.not(msgT) {
t.Fatal("Event type mismatch")
}
// Channel that waits for call() to die.
die := make(chan struct{})
// Call in a goroutine, which would trigger a close.
go func() { h.call(msgV); die <- struct{}{} }()
// Call the cleanup function, which should stop the send.
h.cleanup()
// Check if we still have things being sent.
select {
case <-die:
// pass
case <-time.After(200 * time.Millisecond):
t.Fatal("Timed out waiting for call routine to die.")
}
// Check if we still receive something.
select {
case <-results:
t.Fatal("Unexpected results received.")
default:
// pass
}
}
func TestHandlerInterface(t *testing.T) {
var results = make(chan interface{})
h, err := reflectFn(func(m interface{}) {
h, err := newHandler(func(m interface{}) {
results <- m
})
if err != nil {
@ -121,7 +192,7 @@ func TestHandlerInterface(t *testing.T) {
t.Fatal("Assertion failed:", recv)
}
func TestHandlerWait(t *testing.T) {
func TestHandlerWaitFor(t *testing.T) {
inc := make(chan interface{}, 1)
h := New()
@ -173,7 +244,7 @@ func TestHandlerWait(t *testing.T) {
}
}
func TestHandlerChan(t *testing.T) {
func TestHandlerChanFor(t *testing.T) {
h := New()
wanted := &gateway.TypingStartEvent{
@ -208,7 +279,7 @@ func TestHandlerChan(t *testing.T) {
}
func BenchmarkReflect(b *testing.B) {
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {})
h, err := newHandler(func(m *gateway.MessageCreateEvent) {})
if err != nil {
b.Fatal(err)
}

View file

@ -1,172 +0,0 @@
// Package heart implements a general purpose pacemaker.
package heart
import (
"sync"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
// Debug is the default logger that Pacemaker uses.
var Debug = func(v ...interface{}) {}
var ErrDead = errors.New("no heartbeat replied")
// AtomicTime is a thread-safe UnixNano timestamp guarded by atomic.
type AtomicTime struct {
unixnano int64
}
func (t *AtomicTime) Get() int64 {
return atomic.LoadInt64(&t.unixnano)
}
func (t *AtomicTime) Set(time time.Time) {
atomic.StoreInt64(&t.unixnano, time.UnixNano())
}
func (t *AtomicTime) Time() time.Time {
return time.Unix(0, t.Get())
}
type atomicStop atomic.Value
func (s *atomicStop) Stop() bool {
if v := (*atomic.Value)(s).Load(); v != nil {
ch := v.(chan struct{})
close(ch)
return true
}
return false
}
func (s *atomicStop) Recv() <-chan struct{} {
if v := (*atomic.Value)(s).Load(); v != nil {
return v.(chan struct{})
}
return nil
}
func (s *atomicStop) SetNil() {
(*atomic.Value)(s).Store((chan struct{})(nil))
}
func (s *atomicStop) Reset() {
(*atomic.Value)(s).Store(make(chan struct{}))
}
type Pacemaker struct {
// Heartrate is the received duration between heartbeats.
Heartrate time.Duration
// Time in nanoseconds, guarded by atomic read/writes.
SentBeat AtomicTime
EchoBeat AtomicTime
// Any callback that returns an error will stop the pacer.
Pace func() error
stop atomicStop
death chan error
}
func NewPacemaker(heartrate time.Duration, pacer func() error) *Pacemaker {
return &Pacemaker{
Heartrate: heartrate,
Pace: pacer,
}
}
func (p *Pacemaker) Echo() {
// Swap our received heartbeats
// p.LastBeat[0], p.LastBeat[1] = time.Now(), p.LastBeat[0]
p.EchoBeat.Set(time.Now())
}
// Dead, if true, will have Pace return an ErrDead.
func (p *Pacemaker) Dead() bool {
/* Deprecated
if p.LastBeat[0].IsZero() || p.LastBeat[1].IsZero() {
return false
}
return p.LastBeat[0].Sub(p.LastBeat[1]) > p.Heartrate*2
*/
var (
echo = p.EchoBeat.Get()
sent = p.SentBeat.Get()
)
if echo == 0 || sent == 0 {
return false
}
return sent-echo > int64(p.Heartrate)*2
}
func (p *Pacemaker) Stop() {
if p.stop.Stop() {
Debug("(*Pacemaker).stop was sent a stop signal.")
} else {
Debug("(*Pacemaker).stop is nil, skipping.")
}
}
func (p *Pacemaker) start() error {
// Reset states to its old position.
p.EchoBeat.Set(time.Time{})
p.SentBeat.Set(time.Time{})
// Create a new ticker.
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
// Echo at least once
p.Echo()
for {
if err := p.Pace(); err != nil {
return err
}
// Paced, save:
p.SentBeat.Set(time.Now())
if p.Dead() {
return ErrDead
}
select {
case <-p.stop.Recv():
return nil
case <-tick.C:
}
}
}
// StartAsync starts the pacemaker asynchronously. The WaitGroup is optional.
func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) {
p.death = make(chan error)
p.stop.Reset()
if wg != nil {
wg.Add(1)
}
go func() {
p.death <- p.start()
// Debug.
Debug("Pacemaker returned.")
// Mark the stop channel as nil, so later Close() calls won't block forever.
p.stop.SetNil()
// Mark the pacemaker loop as done.
if wg != nil {
wg.Done()
}
}()
return p.death
}

View file

@ -67,19 +67,22 @@ func (c *Client) Context() context.Context {
return c.context
}
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
// applyOptions tries to apply all options. It does not halt if a single option
// fails, and the error returned is the latest error.
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) (e error) {
for _, opt := range c.OnRequest {
if err := opt(r); err != nil {
return err
}
}
for _, opt := range extra {
if err := opt(r); err != nil {
return err
e = err
}
}
return nil
for _, opt := range extra {
if err := opt(r); err != nil {
e = err
}
}
return
}
func (c *Client) MeanwhileMultipart(
@ -158,6 +161,8 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
var r httpdriver.Response
var status int
// The c.Retries < 1 check ensures that we retry forever if that field is
// less than 1.
for i := uint(0); c.Retries < 1 || i < c.Retries; i++ {
q, err := c.Client.NewRequest(c.context, method, url)
if err != nil {
@ -165,18 +170,33 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
}
if err := c.applyOptions(q, opts); err != nil {
return nil, errors.Wrap(err, "Failed to apply options")
// We failed to apply an option, so we should call all OnResponse
// handler to clean everything up.
for _, fn := range c.OnResponse {
fn(q, nil)
}
// Exit after cleaning everything up.
return nil, errors.Wrap(err, "failed to apply options")
}
r, doErr = c.Client.Do(q)
// Error that represents the latest error in the chain.
var onRespErr error
// Call OnResponse() even if the request failed.
for _, fn := range c.OnResponse {
// Be sure to call ALL OnResponse handlers.
if err := fn(q, r); err != nil {
return nil, err
onRespErr = err
}
}
if onRespErr != nil {
return nil, errors.Wrap(err, "OnResponse handler failed")
}
// Retry if the request failed.
if doErr != nil {
continue
}

View file

@ -22,7 +22,7 @@ type RequestError struct {
}
func (r RequestError) Error() string {
return "Request failed: " + r.err.Error()
return "request failed: " + r.err.Error()
}
func (r RequestError) Unwrap() error {

Some files were not shown because too many files have changed in this diff Show more