diff options
-rw-r--r-- | internal/command/twofactorrecover/twofactorrecover.go | 10 | ||||
-rw-r--r-- | internal/command/twofactorverify/twofactorverify.go | 8 | ||||
-rw-r--r-- | internal/sshd/session.go | 16 |
3 files changed, 23 insertions, 11 deletions
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go index 91eca3a..6982689 100644 --- a/internal/command/twofactorrecover/twofactorrecover.go +++ b/internal/command/twofactorrecover/twofactorrecover.go @@ -26,7 +26,7 @@ func (c *Command) Execute(ctx context.Context) error { ctxlog := log.ContextLogger(ctx) ctxlog.Debug("twofactorrecover: execute: Waiting for user input") - if c.canContinue() { + if c.getUserAnswer(ctx) == "yes" { ctxlog.Debug("twofactorrecover: execute: User chose to continue") c.displayRecoveryCodes(ctx) } else { @@ -37,16 +37,18 @@ func (c *Command) Execute(ctx context.Context) error { return nil } -func (c *Command) canContinue() bool { +func (c *Command) getUserAnswer(ctx context.Context) string { question := "Are you sure you want to generate new two-factor recovery codes?\n" + "Any existing recovery codes you saved will be invalidated. (yes/no)" fmt.Fprintln(c.ReadWriter.Out, question) var answer string - fmt.Fscanln(io.LimitReader(c.ReadWriter.In, readerLimit), &answer) + if _, err := fmt.Fscanln(io.LimitReader(c.ReadWriter.In, readerLimit), &answer); err != nil { + log.ContextLogger(ctx).WithError(err).Debug("twofactorrecover: getUserAnswer: Failed to get user input") + } - return answer == "yes" + return answer } func (c *Command) displayRecoveryCodes(ctx context.Context) { diff --git a/internal/command/twofactorverify/twofactorverify.go b/internal/command/twofactorverify/twofactorverify.go index fe17339..0099e84 100644 --- a/internal/command/twofactorverify/twofactorverify.go +++ b/internal/command/twofactorverify/twofactorverify.go @@ -22,7 +22,7 @@ type Command struct { func (c *Command) Execute(ctx context.Context) error { ctxlog := log.ContextLogger(ctx) ctxlog.Info("twofactorverify: execute: waiting for user input") - otp := c.getOTP() + otp := c.getOTP(ctx) ctxlog.Info("twofactorverify: execute: verifying entered OTP") err := c.verifyOTP(ctx, otp) @@ -35,14 +35,16 @@ func (c *Command) Execute(ctx context.Context) error { return nil } -func (c *Command) getOTP() string { +func (c *Command) getOTP(ctx context.Context) string { prompt := "OTP: " fmt.Fprint(c.ReadWriter.Out, prompt) var answer string otpLength := int64(64) reader := io.LimitReader(c.ReadWriter.In, otpLength) - fmt.Fscanln(reader, &answer) + if _, err := fmt.Fscanln(reader, &answer); err != nil { + log.ContextLogger(ctx).WithError(err).Debug("twofactorverify: getOTP: Failed to get user input") + } return answer } diff --git a/internal/sshd/session.go b/internal/sshd/session.go index b8e8625..ff8540b 100644 --- a/internal/sshd/session.go +++ b/internal/sshd/session.go @@ -66,8 +66,11 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) { default: // Ignore unknown requests but don't terminate the session shouldContinue = true + if req.WantReply { - req.Reply(false, []byte{}) + if err := req.Reply(false, []byte{}); err != nil { + sessionLog.WithError(err).Debug("session: handle: Failed to reply") + } } } @@ -100,7 +103,9 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool { } if req.WantReply { - req.Reply(accepted, []byte{}) + if err := req.Reply(accepted, []byte{}); err != nil { + log.ContextLogger(ctx).WithError(err).Debug("session: handleEnv: Failed to reply") + } } log.WithContextFields( @@ -124,8 +129,12 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool { } func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { + ctxlog := log.ContextLogger(ctx) + if req.WantReply { - req.Reply(true, []byte{}) + if err := req.Reply(true, []byte{}); err != nil { + ctxlog.WithError(err).Debug("session: handleShell: Failed to reply") + } } env := sshenv.Env{ @@ -151,7 +160,6 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 { } cmdName := reflect.TypeOf(cmd).String() - ctxlog := log.ContextLogger(ctx) ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command") if err := cmd.Execute(ctx); err != nil { |