or run

npx @tessl/cli init
Log in

Version

Tile

Overview

Evals

Files

docs

admin.md advanced.md client-server.md credentials-security.md errors-status.md health.md index.md interceptors.md load-balancing.md metadata-context.md name-resolution.md observability.md reflection.md streaming.md testing.md xds.md
tile.json

testing.mddocs/

Testing Utilities

This document covers testing utilities and best practices for testing gRPC services in Go.

Overview

gRPC-Go provides several utilities for testing, including in-memory connections, mock servers, and test helpers.

bufconn - In-Memory Testing

Overview

The bufconn package provides in-memory full-duplex network connections for testing without using actual network sockets.

import "google.golang.org/grpc/test/bufconn"

// Listener is the in-memory listener returned by Listen. Connections are
// created in pairs: Dial/DialContext hand out the client half and Accept
// hands out the matching server half.
type Listener struct {
    // Has unexported fields
}

// Listen returns an in-memory listener with specified buffer size.
// NOTE(review): sz is presumably a size in bytes — confirm against the
// bufconn package documentation.
func Listen(sz int) *Listener

// Accept blocks until Dial is called, then returns server half of connection.
func (l *Listener) Accept() (net.Conn, error)

// Dial creates in-memory connection and returns client half.
func (l *Listener) Dial() (net.Conn, error)

// DialContext creates connection with context support; it is the variant
// to call from a grpc.WithContextDialer callback, which supplies a context.
func (l *Listener) DialContext(ctx context.Context) (net.Conn, error)

// Close stops the listener.
func (l *Listener) Close() error

// Addr reports the address of the listener.
func (l *Listener) Addr() net.Addr

Basic Test Setup

import (
    "context"
    "net"
    "testing"
    "google.golang.org/grpc"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/test/bufconn"
)

// bufSize is the buffer size handed to bufconn.Listen.
const bufSize = 1024 * 1024

// setupTest starts a gRPC server backed by an in-memory bufconn listener and
// returns the server, a client connection wired to it, and a cleanup function
// that the caller must invoke (typically via defer) to release all three.
//
// On failure the test is aborted with t.Fatalf after the server and listener
// have been shut down, so nothing is leaked.
func setupTest(t *testing.T) (*grpc.Server, *grpc.ClientConn, func()) {
    // Attribute failures to the calling test rather than to this helper.
    t.Helper()

    // Create in-memory listener
    lis := bufconn.Listen(bufSize)

    // Create and start server
    server := grpc.NewServer()
    pb.RegisterMyServiceServer(server, &myServiceImpl{})

    go func() {
        // ErrServerStopped only means Stop ran before Serve got going; it is
        // part of a normal shutdown and must not fail the test.
        if err := server.Serve(lis); err != nil && err != grpc.ErrServerStopped {
            t.Errorf("Server exited with error: %v", err)
        }
    }()

    // Create client connection.
    //
    // grpc.NewClient resolves targets with the "dns" scheme by default, so a
    // bare "bufnet" would be sent to DNS and never reach the custom dialer.
    // The "passthrough" scheme hands the address to the dialer untouched.
    conn, err := grpc.NewClient("passthrough:///bufnet",
        grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
            return lis.DialContext(ctx)
        }),
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        // Fatalf stops the test immediately, so release resources first.
        server.Stop()
        lis.Close()
        t.Fatalf("Failed to dial bufnet: %v", err)
    }

    // Return server, conn, and cleanup function
    cleanup := func() {
        conn.Close()
        server.Stop()
        lis.Close()
    }

    return server, conn, cleanup
}

// TestMyService sends one unary MyMethod request over the in-memory
// connection and verifies the message in the reply.
func TestMyService(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    // Test RPCs
    req := &pb.Request{
        Name: "test",
    }
    resp, err := pb.NewMyServiceClient(cc).MyMethod(context.Background(), req)
    if err != nil {
        t.Fatalf("MyMethod failed: %v", err)
    }

    if got := resp.Message; got != "expected" {
        t.Errorf("Expected 'expected', got %q", got)
    }
}

Table-Driven Tests

// TestMyService_TableDriven runs MyMethod against a table of inputs, one
// subtest per row. A row either expects an error (wantErr) or an exact
// response (want), compared with proto.Equal.
func TestMyService_TableDriven(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    client := pb.NewMyServiceClient(cc)

    longName := strings.Repeat("a", 1000)

    cases := []struct {
        name    string
        input   *pb.Request
        want    *pb.Response
        wantErr bool
    }{
        {
            name:  "valid request",
            input: &pb.Request{Name: "test"},
            want:  &pb.Response{Message: "Hello, test"},
        },
        {
            name:    "empty name",
            input:   &pb.Request{Name: ""},
            wantErr: true,
        },
        {
            name:  "long name",
            input: &pb.Request{Name: longName},
            want:  &pb.Response{Message: "Hello, " + longName},
        },
    }

    for _, tc := range cases {
        t.Run(tc.name, func(t *testing.T) {
            got, err := client.MyMethod(context.Background(), tc.input)

            gotErr := err != nil
            switch {
            case gotErr != tc.wantErr:
                // Error expectation mismatch; the response is not inspected.
                t.Errorf("MyMethod() error = %v, wantErr %v", err, tc.wantErr)
            case !tc.wantErr && !proto.Equal(got, tc.want):
                t.Errorf("MyMethod() = %v, want %v", got, tc.want)
            }
        })
    }
}

Testing Streaming RPCs

Server Streaming

// TestServerStreaming opens a server-streaming call, drains it until io.EOF,
// and checks how many messages arrived.
func TestServerStreaming(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    stream, err := pb.NewMyServiceClient(cc).ServerStreamMethod(context.Background(), &pb.Request{})
    if err != nil {
        t.Fatalf("ServerStreamMethod failed: %v", err)
    }

    var got []*pb.Response
    for {
        msg, recvErr := stream.Recv()
        if recvErr != nil {
            // io.EOF marks the normal end of the stream.
            if recvErr == io.EOF {
                break
            }
            t.Fatalf("Recv failed: %v", recvErr)
        }
        got = append(got, msg)
    }

    if n := len(got); n != 3 {
        t.Errorf("Expected 3 responses, got %d", n)
    }
}

Client Streaming

// TestClientStreaming sends three requests on a client-streaming call, closes
// the stream, and checks the count reported in the single response.
func TestClientStreaming(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    stream, err := pb.NewMyServiceClient(cc).ClientStreamMethod(context.Background())
    if err != nil {
        t.Fatalf("ClientStreamMethod failed: %v", err)
    }

    for _, name := range []string{"req1", "req2", "req3"} {
        sendErr := stream.Send(&pb.Request{Name: name})
        if sendErr != nil {
            t.Fatalf("Send failed: %v", sendErr)
        }
    }

    summary, err := stream.CloseAndRecv()
    if err != nil {
        t.Fatalf("CloseAndRecv failed: %v", err)
    }

    if got := summary.Count; got != 3 {
        t.Errorf("Expected count=3, got %d", got)
    }
}

Bidirectional Streaming

// TestBidiStreaming sends three requests on a bidirectional stream from a
// separate goroutine while the test goroutine drains the responses, then
// checks that three responses arrived.
//
// The sender never touches t. It reports its outcome over a buffered channel
// that the test goroutine reads before returning, so nothing can log to t
// after the test has finished.
func TestBidiStreaming(t *testing.T) {
    _, conn, cleanup := setupTest(t)
    defer cleanup()

    client := pb.NewMyServiceClient(conn)

    stream, err := client.BidiStreamMethod(context.Background())
    if err != nil {
        t.Fatalf("BidiStreamMethod failed: %v", err)
    }

    // Send requests. The channel is buffered so the sender can always finish,
    // even if the receive loop below aborts the test with Fatalf.
    sendErr := make(chan error, 1)
    go func() {
        requests := []*pb.Request{
            {Name: "req1"},
            {Name: "req2"},
            {Name: "req3"},
        }
        for _, req := range requests {
            if err := stream.Send(req); err != nil {
                // Stop at the first failure instead of sending on a broken stream.
                sendErr <- err
                return
            }
        }
        sendErr <- stream.CloseSend()
    }()

    // Receive responses
    var responses []*pb.Response
    for {
        resp, err := stream.Recv()
        if err == io.EOF {
            break
        }
        if err != nil {
            t.Fatalf("Recv failed: %v", err)
        }
        responses = append(responses, resp)
    }

    // Wait for the sender and surface any Send/CloseSend error.
    if err := <-sendErr; err != nil {
        t.Errorf("Send failed: %v", err)
    }

    if len(responses) != 3 {
        t.Errorf("Expected 3 responses, got %d", len(responses))
    }
}

Testing Error Handling

import (
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/status"
)

// TestErrorHandling requests a user that does not exist and verifies that the
// call fails with codes.NotFound and a message mentioning the requested id.
func TestErrorHandling(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    // Test not found error
    lookup := &pb.GetUserRequest{
        Id: "nonexistent",
    }
    _, err := pb.NewMyServiceClient(cc).GetUser(context.Background(), lookup)
    if err == nil {
        t.Fatal("Expected error, got nil")
    }

    st := status.Convert(err)

    if code := st.Code(); code != codes.NotFound {
        t.Errorf("Expected NotFound, got %v", code)
    }

    if msg := st.Message(); !strings.Contains(msg, "nonexistent") {
        t.Errorf("Error message should contain 'nonexistent': %v", msg)
    }
}

Testing with Context

Deadline/Timeout

// TestTimeout calls SlowMethod with a 10ms deadline and expects the call to
// fail with codes.DeadlineExceeded.
func TestTimeout(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
    defer cancel()

    _, err := pb.NewMyServiceClient(cc).SlowMethod(ctx, &pb.Request{})
    if err == nil {
        t.Fatal("Expected timeout error")
    }

    if code := status.Convert(err).Code(); code != codes.DeadlineExceeded {
        t.Errorf("Expected DeadlineExceeded, got %v", code)
    }
}

Cancellation

// TestCancellation starts LongRunningMethod, cancels its context about 10ms
// later from another goroutine, and expects codes.Canceled.
func TestCancellation(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    ctx, cancel := context.WithCancel(context.Background())

    // AfterFunc invokes cancel in its own goroutine once the delay elapses.
    time.AfterFunc(10*time.Millisecond, cancel)

    _, err := pb.NewMyServiceClient(cc).LongRunningMethod(ctx, &pb.Request{})
    if err == nil {
        t.Fatal("Expected cancellation error")
    }

    if code := status.Convert(err).Code(); code != codes.Canceled {
        t.Errorf("Expected Canceled, got %v", code)
    }
}

Testing Interceptors

// TestInterceptor installs a unary server interceptor that records whether it
// ran, performs one MyMethod call over an in-memory connection, and fails if
// the interceptor was never invoked.
func TestInterceptor(t *testing.T) {
    var interceptorCalled bool

    testInterceptor := func(
        ctx context.Context,
        req any,
        info *grpc.UnaryServerInfo,
        handler grpc.UnaryHandler,
    ) (any, error) {
        interceptorCalled = true
        return handler(ctx, req)
    }

    lis := bufconn.Listen(bufSize)
    server := grpc.NewServer(grpc.UnaryInterceptor(testInterceptor))
    pb.RegisterMyServiceServer(server, &myServiceImpl{})
    defer server.Stop()

    go server.Serve(lis)

    // grpc.NewClient resolves targets with the "dns" scheme by default; the
    // "passthrough" scheme is needed so the address goes straight to the
    // bufconn dialer instead of a DNS lookup.
    conn, err := grpc.NewClient("passthrough:///bufnet",
        grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
            return lis.DialContext(ctx)
        }),
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        // Without this check a failed NewClient leaves conn nil and the
        // deferred Close below would panic.
        t.Fatalf("Failed to dial bufnet: %v", err)
    }
    defer conn.Close()

    client := pb.NewMyServiceClient(conn)
    // The interceptor wraps the handler, so it runs even when the handler
    // rejects the request. Log the error for diagnosis instead of failing.
    if _, err := client.MyMethod(context.Background(), &pb.Request{}); err != nil {
        t.Logf("MyMethod returned an error: %v", err)
    }

    if !interceptorCalled {
        t.Error("Interceptor was not called")
    }
}

Testing Metadata

import (
    "google.golang.org/grpc/metadata"
)

// TestMetadata attaches outgoing metadata to a MyMethod call, captures the
// response headers through the grpc.Header call option, and checks that the
// "response-key" header is present.
func TestMetadata(t *testing.T) {
    _, cc, teardown := setupTest(t)
    defer teardown()

    client := pb.NewMyServiceClient(cc)

    // Send metadata
    ctx := metadata.NewOutgoingContext(context.Background(), metadata.Pairs("key", "value"))

    var respHeader metadata.MD
    if _, err := client.MyMethod(ctx, &pb.Request{}, grpc.Header(&respHeader)); err != nil {
        t.Fatalf("MyMethod failed: %v", err)
    }

    // Check response header
    if len(respHeader.Get("response-key")) == 0 {
        t.Error("Expected response header not found")
    }
}

Mock Servers

Simple Mock

// mockServer is a test double for MyService. It embeds
// pb.UnimplementedMyServiceServer and overrides only MyMethod.
type mockServer struct {
    pb.UnimplementedMyServiceServer
}

// MyMethod ignores the incoming request and always answers with the same
// canned response and a nil error.
func (s *mockServer) MyMethod(ctx context.Context, req *pb.Request) (*pb.Response, error) {
    canned := &pb.Response{Message: "mock response"}
    return canned, nil
}

// TestWithMock serves mockServer over an in-memory listener and verifies that
// MyMethod returns the mock's canned message.
func TestWithMock(t *testing.T) {
    lis := bufconn.Listen(bufSize)
    server := grpc.NewServer()
    pb.RegisterMyServiceServer(server, &mockServer{})
    defer server.Stop()

    go server.Serve(lis)

    // grpc.NewClient resolves targets with the "dns" scheme by default; the
    // "passthrough" scheme is needed so the address goes straight to the
    // bufconn dialer instead of a DNS lookup.
    conn, err := grpc.NewClient("passthrough:///bufnet",
        grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
            return lis.DialContext(ctx)
        }),
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        // Without this check a failed NewClient leaves conn nil and the
        // deferred Close below would panic.
        t.Fatalf("Failed to dial bufnet: %v", err)
    }
    defer conn.Close()

    client := pb.NewMyServiceClient(conn)
    resp, err := client.MyMethod(context.Background(), &pb.Request{})
    if err != nil {
        t.Fatal(err)
    }

    if resp.Message != "mock response" {
        t.Errorf("Expected 'mock response', got %q", resp.Message)
    }
}

Best Practices

Test Organization

  1. Use subtests: Organize tests with t.Run()
  2. Table-driven: Use table-driven tests for multiple cases
  3. Setup/teardown: Use helper functions for setup and cleanup
  4. Parallel tests: Use t.Parallel() when tests are independent

Coverage

  1. Happy path: Test successful scenarios
  2. Error cases: Test all error conditions
  3. Edge cases: Test boundary conditions
  4. Streaming: Test all streaming patterns
  5. Cancellation: Test context cancellation
  6. Timeouts: Test deadline exceeded scenarios

Performance

  1. Use bufconn: In-memory connections are faster than a real network
  2. Reuse connections: Share connections across tests when possible
  3. Cleanup: Always clean up resources
  4. Parallel execution: Run independent tests in parallel

Example Test Suite

package myservice_test

import (
    "context"
    "io"
    "net"
    "testing"
    "time"

    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/credentials/insecure"
    "google.golang.org/grpc/status"
    "google.golang.org/grpc/test/bufconn"

    pb "mypackage/proto"
    "mypackage/server"
)

// bufSize is the buffer size handed to bufconn.Listen.
const bufSize = 1024 * 1024

// setup starts the service on an in-memory bufconn listener and returns a
// client for it together with a cleanup function that closes the connection,
// stops the server and closes the listener. Callers should defer cleanup.
//
// On failure the test is aborted with t.Fatalf after the server and listener
// have been shut down.
func setup(t *testing.T) (pb.MyServiceClient, func()) {
    // Attribute failures to the calling test rather than to this helper.
    t.Helper()

    lis := bufconn.Listen(bufSize)
    s := grpc.NewServer()
    pb.RegisterMyServiceServer(s, server.New())

    go func() {
        // ErrServerStopped only means Stop ran before Serve got going; it is
        // part of a normal shutdown and not worth logging.
        if err := s.Serve(lis); err != nil && err != grpc.ErrServerStopped {
            t.Logf("Server exited: %v", err)
        }
    }()

    // grpc.NewClient resolves targets with the "dns" scheme by default; the
    // "passthrough" scheme is needed so the address goes straight to the
    // bufconn dialer instead of a DNS lookup.
    conn, err := grpc.NewClient("passthrough:///bufnet",
        grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
            return lis.DialContext(ctx)
        }),
        grpc.WithTransportCredentials(insecure.NewCredentials()))
    if err != nil {
        // Fatalf stops the test immediately, so release resources first.
        s.Stop()
        lis.Close()
        t.Fatalf("Failed to dial: %v", err)
    }

    cleanup := func() {
        conn.Close()
        s.Stop()
        lis.Close()
    }

    return pb.NewMyServiceClient(conn), cleanup
}

// TestMyService exercises the service through three subtests that share one
// in-memory client: a unary call, a server-streaming call, and an
// invalid-argument error case.
//
// All RPCs share a context with a 5-second deadline so that a misbehaving
// server fails the test instead of hanging it.
func TestMyService(t *testing.T) {
    client, cleanup := setup(t)
    defer cleanup()

    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    t.Run("UnaryMethod", func(t *testing.T) {
        resp, err := client.MyMethod(ctx, &pb.Request{Name: "test"})
        if err != nil {
            t.Fatal(err)
        }
        if resp.Message != "Hello, test" {
            t.Errorf("got %q, want %q", resp.Message, "Hello, test")
        }
    })

    t.Run("StreamMethod", func(t *testing.T) {
        stream, err := client.StreamMethod(ctx, &pb.Request{})
        if err != nil {
            t.Fatal(err)
        }

        // Drain the stream; io.EOF marks its normal end.
        count := 0
        for {
            _, err := stream.Recv()
            if err == io.EOF {
                break
            }
            if err != nil {
                t.Fatal(err)
            }
            count++
        }

        if count != 5 {
            t.Errorf("got %d messages, want 5", count)
        }
    })

    t.Run("ErrorHandling", func(t *testing.T) {
        _, err := client.MyMethod(ctx, &pb.Request{Name: ""})
        if err == nil {
            t.Fatal("expected error")
        }

        if code := status.Code(err); code != codes.InvalidArgument {
            t.Errorf("got %v, want InvalidArgument", code)
        }
    })
}