Concurrency Transaction in Go is not working

56 Views Asked by At

I am relatively new to Go Lang and am currently undertaking a learning journey through the Udemy course found at https://www.udemy.com/course/backend-master-class-golang-postgresql-kubernetes. As part of this course, I am attempting to develop a basic banking application using Go.

However, I have encountered an issue with the transaction functionality in my application. Specifically, only the first transaction seems to be successful, while subsequent transactions do not update as expected. I am keen to understand the reasons behind this failure and would appreciate any insights or guidance on resolving the issue.

I've attempted to troubleshoot the issue by extensively logging information and inspecting the PostgreSQL locks query. Despite my efforts, I'm still unable to identify the root cause of the problem in my code. I would greatly appreciate any assistance from the community in pinpointing and resolving the issue. Please let me know if there are specific parts of my code or additional details that would be helpful in diagnosing the problem.

store.go


import (
    "context"
    "database/sql"
    "fmt"
)

//Store provides all functions to execute db queries and transactions

type Store struct {
    *Queries
    db *sql.DB
}

// NewStore creates a new store
func NewStore(db *sql.DB) *Store {
    return &Store{
        db:      db,
        Queries: New(db),
    }
}

// execTx executes a function within a database transaction
func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
    tx, err := store.db.BeginTx(ctx, nil)

    if err != nil {
        return err
    }

    q := New(tx)
    err = fn(q)

    if err != nil {
        if rbErr := tx.Rollback(); rbErr != nil {
            return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
        }
        return err
    }

    return tx.Commit()
}

// TransferTxParams contains the input parameters of the transfer transaction
type TransferTxParams struct {
    FromAccountID int64 `json:"from_account_id"`
    ToAccountID   int64 `json:"to_account_id"`
    Amount        int64 `json:"amount"`
}

// TransferTxResult is the result of the transfer transaction
type TransferTxResult struct {
    Transfer    Transfer `json:"transfer"`
    FromAccount Account  `json:"from_account"`
    ToAccount   Account  `json:"to_account"`
    FromEntry   Entry    `json:"from_entry"`
    ToEntry     Entry    `json:"to_entry"`
}

var txKey = struct{}{}

// TransferTx performs a money transfer from one account to the oother
// It creates a transfer record, add account entries, and update accounts' balance within a single database transaction
func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
    var result TransferTxResult

    err := store.execTx(ctx, func(q *Queries) error {
        var err error

        result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
            FromAccountID: arg.FromAccountID,
            ToAccountID:   arg.ToAccountID,
            Amount:        arg.Amount,
        })

        if err != nil {
            return err
        }

        result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
            AccountID: arg.FromAccountID,
            Amount:    -arg.Amount,
        })

        if err != nil {
            return err
        }

        result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
            AccountID: arg.ToAccountID,
            Amount:    arg.Amount,
        })

        if err != nil {
            return err
        }

        result.FromAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
            ID:     arg.FromAccountID,
            Amount: -arg.Amount,
        })

        if err != nil {
            return err
        }

        result.ToAccount, err = q.AddAccountBalance(ctx, AddAccountBalanceParams{
            ID:     arg.ToAccountID,
            Amount: arg.Amount,
        })

        if err != nil {
            return err
        }

        // TODO: update acccounts balance

        return nil

    })

    return result, err

}

store_test.go


import (
    "context"
    "fmt"
    "log"
    "testing"

    "github.com/stretchr/testify/require"
)

func TestTransferTx(t *testing.T) {
    store := NewStore(testDB)

    account1 := createRandomAccount(t)
    account2 := createRandomAccount(t)

    fmt.Println(">> before:", account1.Balance, account2.Balance)

    n := 3
    amount := int64(10)

    errs := make(chan error)
    results := make(chan TransferTxResult)

    for i := 0; i < n; i++ {
        txName := fmt.Sprintf("tx %d", i+1)
        go func() {
            ctx := context.Background()
            result, err := store.TransferTx(ctx, TransferTxParams{
                FromAccountID: account1.ID,
                ToAccountID:   account2.ID,
                Amount:        amount,
            })

            if err != nil {
                log.Printf("%v: Error in transaction: %v\n", txName, err)
            } else {
                log.Printf("%v: Transaction successful\n", txName)
            }

            errs <- err
            results <- result
        }()
    }

    // check results
    existed := make(map[int]bool)
    for i := 0; i < n; i++ {
        err := <-errs
        require.NoError(t, err)

        result := <-results
        require.NotEmpty(t, result)

        //check transfer
        transfer := result.Transfer
        require.NotEmpty(t, transfer)
        require.Equal(t, account1.ID, transfer.FromAccountID)
        require.Equal(t, account2.ID, transfer.ToAccountID)
        require.Equal(t, amount, transfer.Amount)
        require.NotZero(t, transfer.ID)
        require.NotZero(t, transfer.CreatedAt)

        _, err = store.GetTransfer(context.Background(), transfer.ID)
        require.NoError(t, err)

        //check entries
        fromEntry := result.FromEntry
        require.NotEmpty(t, fromEntry)
        require.Equal(t, account1.ID, fromEntry.AccountID)
        require.Equal(t, -amount, fromEntry.Amount)
        require.NotZero(t, fromEntry.ID)
        require.NotZero(t, fromEntry.CreatedAt)

        _, err = store.GetEntry(context.Background(), fromEntry.ID)
        require.NoError(t, err)

        //check entries
        toEntry := result.ToEntry
        require.NotEmpty(t, toEntry)
        require.Equal(t, account2.ID, toEntry.AccountID)
        require.Equal(t, amount, toEntry.Amount)
        require.NotZero(t, toEntry.ID)
        require.NotZero(t, toEntry.CreatedAt)

        _, err = store.GetEntry(context.Background(), toEntry.ID)
        require.NoError(t, err)

        //check account
        fromAccount := result.FromAccount
        require.NotEmpty(t, fromAccount)
        require.Equal(t, account1.ID, fromAccount.ID)

        toAccount := result.ToAccount
        require.NotEmpty(t, toAccount)
        require.Equal(t, account2.ID, toAccount.ID)

        //check accounts balance
        fmt.Println(">> tx:", fromAccount.Balance, toAccount.Balance)
        diff1 := account1.Balance - fromAccount.Balance
        diff2 := toAccount.Balance - account2.Balance
        require.Equal(t, diff1, diff2)
        require.True(t, diff1 > 0)
        require.True(t, diff1%amount == 0)

        k := int(diff1 / amount)
        require.True(t, k >= 1 && k <= n)
        require.NotContains(t, existed, k)
        existed[k] = true

        // check the final updated balances

        updatedAccount1, err := testQueries.GetAccount(context.Background(), account1.ID)
        require.NoError(t, err)

        updatedAccount2, err := testQueries.GetAccount(context.Background(), account2.ID)
        require.NoError(t, err)

        fmt.Println(">> after:", updatedAccount1.Balance, updatedAccount2.Balance)
        // fmt.Println(">> after:", account1.Balance, account2.Balance)

        require.Equal(t, account1.Balance-int64(n)*amount, updatedAccount1.Balance)
        require.Equal(t, account2.Balance+int64(n)*amount, updatedAccount2.Balance)

    }
}

account_sql.go(using sqlc)

// source: account.sql

package db

import (
    "context"
)

const addAccountBalance = `-- name: AddAccountBalance :one
UPDATE account
SET balance = balance + $1
WHERE id = $2
RETURNING id, owner, balance, currency, created_at
`

type AddAccountBalanceParams struct {
    Amount int64 `json:"amount"`
    ID     int64 `json:"id"`
}

func (q *Queries) AddAccountBalance(ctx context.Context, arg AddAccountBalanceParams) (Account, error) {
    row := q.db.QueryRowContext(ctx, addAccountBalance, arg.Amount, arg.ID)
    var i Account
    err := row.Scan(
        &i.ID,
        &i.Owner,
        &i.Balance,
        &i.Currency,
        &i.CreatedAt,
    )
    return i, err
}

const createAccount = `-- name: CreateAccount :one
INSERT INTO account (
  owner, balance, currency
) VALUES (
  $1, $2, $3
) RETURNING id, owner, balance, currency, created_at
`

type CreateAccountParams struct {
    Owner    string `json:"owner"`
    Balance  int64  `json:"balance"`
    Currency string `json:"currency"`
}

func (q *Queries) CreateAccount(ctx context.Context, arg CreateAccountParams) (Account, error) {
    row := q.db.QueryRowContext(ctx, createAccount, arg.Owner, arg.Balance, arg.Currency)
    var i Account
    err := row.Scan(
        &i.ID,
        &i.Owner,
        &i.Balance,
        &i.Currency,
        &i.CreatedAt,
    )
    return i, err
}

const deleteAccount = `-- name: DeleteAccount :exec
DELETE FROM account
WHERE id = $1
`

func (q *Queries) DeleteAccount(ctx context.Context, id int64) error {
    _, err := q.db.ExecContext(ctx, deleteAccount, id)
    return err
}

const getAccount = `-- name: GetAccount :one
SELECT id, owner, balance, currency, created_at FROM account
WHERE id = $1 LIMIT 1
`

func (q *Queries) GetAccount(ctx context.Context, id int64) (Account, error) {
    row := q.db.QueryRowContext(ctx, getAccount, id)
    var i Account
    err := row.Scan(
        &i.ID,
        &i.Owner,
        &i.Balance,
        &i.Currency,
        &i.CreatedAt,
    )
    return i, err
}

const getAccountForUpdate = `-- name: GetAccountForUpdate :one
SELECT id, owner, balance, currency, created_at FROM account
WHERE id = $1 LIMIT 1 FOR NO KEY UPDATE
`

func (q *Queries) GetAccountForUpdate(ctx context.Context, id int64) (Account, error) {
    row := q.db.QueryRowContext(ctx, getAccountForUpdate, id)
    var i Account
    err := row.Scan(
        &i.ID,
        &i.Owner,
        &i.Balance,
        &i.Currency,
        &i.CreatedAt,
    )
    return i, err
}

const listAccounts = `-- name: ListAccounts :many
SELECT id, owner, balance, currency, created_at FROM account
ORDER BY id
LIMIT $1
OFFSET $2
`

type ListAccountsParams struct {
    Limit  int32 `json:"limit"`
    Offset int32 `json:"offset"`
}

func (q *Queries) ListAccounts(ctx context.Context, arg ListAccountsParams) ([]Account, error) {
    rows, err := q.db.QueryContext(ctx, listAccounts, arg.Limit, arg.Offset)
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    var items []Account
    for rows.Next() {
        var i Account
        if err := rows.Scan(
            &i.ID,
            &i.Owner,
            &i.Balance,
            &i.Currency,
            &i.CreatedAt,
        ); err != nil {
            return nil, err
        }
        items = append(items, i)
    }
    if err := rows.Close(); err != nil {
        return nil, err
    }
    if err := rows.Err(); err != nil {
        return nil, err
    }
    return items, nil
}

const updateAccount = `-- name: UpdateAccount :one
UPDATE account
SET balance = $2
WHERE id = $1
RETURNING id, owner, balance, currency, created_at
`

type UpdateAccountParams struct {
    ID      int64 `json:"id"`
    Balance int64 `json:"balance"`
}

func (q *Queries) UpdateAccount(ctx context.Context, arg UpdateAccountParams) (Account, error) {
    row := q.db.QueryRowContext(ctx, updateAccount, arg.ID, arg.Balance)
    var i Account
    err := row.Scan(
        &i.ID,
        &i.Owner,
        &i.Balance,
        &i.Currency,
        &i.CreatedAt,
    )
    return i, err
}

transaction_sql.go

// source: transfer.sql

package db

import (
    "context"
)

const createTransfer = `-- name: CreateTransfer :one
INSERT INTO transfers (
  from_account_id,
  to_account_id,
  amount
) VALUES (
  $1, $2, $3
) RETURNING id, from_account_id, to_account_id, amount, created_at
`

type CreateTransferParams struct {
    FromAccountID int64 `json:"from_account_id"`
    ToAccountID   int64 `json:"to_account_id"`
    Amount        int64 `json:"amount"`
}

func (q *Queries) CreateTransfer(ctx context.Context, arg CreateTransferParams) (Transfer, error) {
    row := q.db.QueryRowContext(ctx, createTransfer, arg.FromAccountID, arg.ToAccountID, arg.Amount)
    var i Transfer
    err := row.Scan(
        &i.ID,
        &i.FromAccountID,
        &i.ToAccountID,
        &i.Amount,
        &i.CreatedAt,
    )
    return i, err
}

const getTransfer = `-- name: GetTransfer :one
SELECT id, from_account_id, to_account_id, amount, created_at FROM transfers
WHERE id = $1 LIMIT 1
`

func (q *Queries) GetTransfer(ctx context.Context, id int64) (Transfer, error) {
    row := q.db.QueryRowContext(ctx, getTransfer, id)
    var i Transfer
    err := row.Scan(
        &i.ID,
        &i.FromAccountID,
        &i.ToAccountID,
        &i.Amount,
        &i.CreatedAt,
    )
    return i, err
}

const listTransfers = `-- name: ListTransfers :many
SELECT id, from_account_id, to_account_id, amount, created_at FROM transfers
WHERE 
    from_account_id = $1 OR
    to_account_id = $2
ORDER BY id
LIMIT $3
OFFSET $4
`

type ListTransfersParams struct {
    FromAccountID int64 `json:"from_account_id"`
    ToAccountID   int64 `json:"to_account_id"`
    Limit         int32 `json:"limit"`
    Offset        int32 `json:"offset"`
}

func (q *Queries) ListTransfers(ctx context.Context, arg ListTransfersParams) ([]Transfer, error) {
    rows, err := q.db.QueryContext(ctx, listTransfers,
        arg.FromAccountID,
        arg.ToAccountID,
        arg.Limit,
        arg.Offset,
    )
    if err != nil {
        return nil, err
    }
    defer rows.Close()
    var items []Transfer
    for rows.Next() {
        var i Transfer
        if err := rows.Scan(
            &i.ID,
            &i.FromAccountID,
            &i.ToAccountID,
            &i.Amount,
            &i.CreatedAt,
        ); err != nil {
            return nil, err
        }
        items = append(items, i)
    }
    if err := rows.Close(); err != nil {
        return nil, err
    }
    if err := rows.Err(); err != nil {
        return nil, err
    }
    return items, nil
}
0

There are 0 best solutions below