summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--internal/command/twofactorrecover/twofactorrecover.go10
-rw-r--r--internal/command/twofactorverify/twofactorverify.go8
-rw-r--r--internal/sshd/session.go16
3 files changed, 23 insertions, 11 deletions
diff --git a/internal/command/twofactorrecover/twofactorrecover.go b/internal/command/twofactorrecover/twofactorrecover.go
index 91eca3a..6982689 100644
--- a/internal/command/twofactorrecover/twofactorrecover.go
+++ b/internal/command/twofactorrecover/twofactorrecover.go
@@ -26,7 +26,7 @@ func (c *Command) Execute(ctx context.Context) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Debug("twofactorrecover: execute: Waiting for user input")
- if c.canContinue() {
+ if c.getUserAnswer(ctx) == "yes" {
ctxlog.Debug("twofactorrecover: execute: User chose to continue")
c.displayRecoveryCodes(ctx)
} else {
@@ -37,16 +37,18 @@ func (c *Command) Execute(ctx context.Context) error {
return nil
}
-func (c *Command) canContinue() bool {
+func (c *Command) getUserAnswer(ctx context.Context) string {
question :=
"Are you sure you want to generate new two-factor recovery codes?\n" +
"Any existing recovery codes you saved will be invalidated. (yes/no)"
fmt.Fprintln(c.ReadWriter.Out, question)
var answer string
- fmt.Fscanln(io.LimitReader(c.ReadWriter.In, readerLimit), &answer)
+ if _, err := fmt.Fscanln(io.LimitReader(c.ReadWriter.In, readerLimit), &answer); err != nil {
+ log.ContextLogger(ctx).WithError(err).Debug("twofactorrecover: getUserAnswer: Failed to get user input")
+ }
- return answer == "yes"
+ return answer
}
func (c *Command) displayRecoveryCodes(ctx context.Context) {
diff --git a/internal/command/twofactorverify/twofactorverify.go b/internal/command/twofactorverify/twofactorverify.go
index fe17339..0099e84 100644
--- a/internal/command/twofactorverify/twofactorverify.go
+++ b/internal/command/twofactorverify/twofactorverify.go
@@ -22,7 +22,7 @@ type Command struct {
func (c *Command) Execute(ctx context.Context) error {
ctxlog := log.ContextLogger(ctx)
ctxlog.Info("twofactorverify: execute: waiting for user input")
- otp := c.getOTP()
+ otp := c.getOTP(ctx)
ctxlog.Info("twofactorverify: execute: verifying entered OTP")
err := c.verifyOTP(ctx, otp)
@@ -35,14 +35,16 @@ func (c *Command) Execute(ctx context.Context) error {
return nil
}
-func (c *Command) getOTP() string {
+func (c *Command) getOTP(ctx context.Context) string {
prompt := "OTP: "
fmt.Fprint(c.ReadWriter.Out, prompt)
var answer string
otpLength := int64(64)
reader := io.LimitReader(c.ReadWriter.In, otpLength)
- fmt.Fscanln(reader, &answer)
+ if _, err := fmt.Fscanln(reader, &answer); err != nil {
+ log.ContextLogger(ctx).WithError(err).Debug("twofactorverify: getOTP: Failed to get user input")
+ }
return answer
}
diff --git a/internal/sshd/session.go b/internal/sshd/session.go
index b8e8625..ff8540b 100644
--- a/internal/sshd/session.go
+++ b/internal/sshd/session.go
@@ -66,8 +66,11 @@ func (s *session) handle(ctx context.Context, requests <-chan *ssh.Request) {
default:
// Ignore unknown requests but don't terminate the session
shouldContinue = true
+
if req.WantReply {
- req.Reply(false, []byte{})
+ if err := req.Reply(false, []byte{}); err != nil {
+ sessionLog.WithError(err).Debug("session: handle: Failed to reply")
+ }
}
}
@@ -100,7 +103,9 @@ func (s *session) handleEnv(ctx context.Context, req *ssh.Request) bool {
}
if req.WantReply {
- req.Reply(accepted, []byte{})
+ if err := req.Reply(accepted, []byte{}); err != nil {
+ log.ContextLogger(ctx).WithError(err).Debug("session: handleEnv: Failed to reply")
+ }
}
log.WithContextFields(
@@ -124,8 +129,12 @@ func (s *session) handleExec(ctx context.Context, req *ssh.Request) bool {
}
func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
+ ctxlog := log.ContextLogger(ctx)
+
if req.WantReply {
- req.Reply(true, []byte{})
+ if err := req.Reply(true, []byte{}); err != nil {
+ ctxlog.WithError(err).Debug("session: handleShell: Failed to reply")
+ }
}
env := sshenv.Env{
@@ -151,7 +160,6 @@ func (s *session) handleShell(ctx context.Context, req *ssh.Request) uint32 {
}
cmdName := reflect.TypeOf(cmd).String()
- ctxlog := log.ContextLogger(ctx)
ctxlog.WithFields(log.Fields{"env": env, "command": cmdName}).Info("session: handleShell: executing command")
if err := cmd.Execute(ctx); err != nil {