diff --git a/bot.go b/bot.go index c6ca084..fbb518e 100644 --- a/bot.go +++ b/bot.go @@ -24,9 +24,8 @@ type BotAPI struct { Debug bool `json:"debug"` Buffer int `json:"buffer"` - Self User `json:"-"` - Client *http.Client `json:"-"` - shutdownChannel chan interface{} + Self User `json:"-"` + Client *http.Client `json:"-"` } // NewBotAPI creates a new BotAPI instance. @@ -42,10 +41,9 @@ func NewBotAPI(token string) (*BotAPI, error) { // It requires a token, provided by @BotFather on Telegram. func NewBotAPIWithClient(token string, client *http.Client) (*BotAPI, error) { bot := &BotAPI{ - Token: token, - Client: client, - Buffer: 100, - shutdownChannel: make(chan interface{}), + Token: token, + Client: client, + Buffer: 100, } self, err := bot.GetMe() @@ -395,11 +393,12 @@ func (bot *BotAPI) GetWebhookInfo() (WebhookInfo, error) { // GetUpdatesChan starts and returns a channel for getting updates. func (bot *BotAPI) GetUpdatesChan(config UpdateConfig) UpdatesChannel { ch := make(chan Update, bot.Buffer) + done := make(chan struct{}) go func() { for { select { - case <-bot.shutdownChannel: + case <-done: return default: } @@ -416,28 +415,37 @@ func (bot *BotAPI) GetUpdatesChan(config UpdateConfig) UpdatesChannel { for _, update := range updates { if update.UpdateID >= config.Offset { config.Offset = update.UpdateID + 1 - ch <- update + + select { + case ch <- update: + case <-done: + return + } } } } }() - return ch -} - -// StopReceivingUpdates stops the go routine which receives updates -func (bot *BotAPI) StopReceivingUpdates() { - if bot.Debug { - log.Println("Stopping the update receiver routine...") + updatesCh := UpdatesChannel{ + channel: ch, + done: done, } - close(bot.shutdownChannel) + + return updatesCh } // ListenForWebhook registers a http handler for a webhook. func (bot *BotAPI) ListenForWebhook(pattern string) UpdatesChannel { ch := make(chan Update, bot.Buffer) + done := make(chan struct{}) http.HandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) { + select { + case <-done: + return + default: + } + bytes, _ := ioutil.ReadAll(r.Body) var update Update @@ -446,7 +454,12 @@ func (bot *BotAPI) ListenForWebhook(pattern string) UpdatesChannel { ch <- update }) - return ch + updatesCh := UpdatesChannel{ + channel: ch, + done: done, + } + + return updatesCh } // GetChat gets information about a chat. diff --git a/bot_test.go b/bot_test.go index fe1fd55..41a5424 100644 --- a/bot_test.go +++ b/bot_test.go @@ -552,7 +552,7 @@ func ExampleNewBotAPI() { time.Sleep(time.Millisecond * 500) updates.Clear() - for update := range updates { + for update := range updates.Channel() { if update.Message == nil { continue } @@ -594,7 +594,7 @@ func ExampleNewWebhook() { updates := bot.ListenForWebhook("/" + bot.Token) go http.ListenAndServeTLS("0.0.0.0:8443", "cert.pem", "key.pem", nil) - for update := range updates { + for update := range updates.Channel() { log.Printf("%+v\n", update) } } @@ -612,7 +612,7 @@ func ExampleInlineConfig() { updates := bot.GetUpdatesChan(u) - for update := range updates { + for update := range updates.Channel() { if update.InlineQuery == nil { // if no inline query, ignore it continue } diff --git a/types.go b/types.go index 78ebcf3..80d30d5 100644 --- a/types.go +++ b/types.go @@ -40,13 +40,28 @@ type Update struct { Poll *Poll `json:"poll"` } -// UpdatesChannel is the channel for getting updates. -type UpdatesChannel <-chan Update +// UpdatesChannel is the struct that holds a channel for getting updates. +type UpdatesChannel struct { + channel chan Update + done chan struct{} +} + +// Return Update channel +func (updatesCh UpdatesChannel) Channel() <-chan Update { + return updatesCh.channel +} + +// Stop channel feeding by goroutine or http handlers. +// +// It may not feed the update channel with all fetched/received Updates. +func (updatesCh UpdatesChannel) Shutdown() { + updatesCh.done <- struct{}{} +} // Clear discards all unprocessed incoming updates. -func (ch UpdatesChannel) Clear() { - for len(ch) != 0 { - <-ch +func (updatesCh UpdatesChannel) Clear() { + for len(updatesCh.channel) != 0 { + <-updatesCh.channel } }