This document covers testing utilities and best practices for testing gRPC services in Go.
gRPC-Go provides several utilities for testing, including in-memory connections, mock servers, and test helpers.
The bufconn package provides in-memory, full-duplex network connections, so tests can run a real gRPC client and server against each other without opening actual network sockets. Its API is:
import "google.golang.org/grpc/test/bufconn"
type Listener struct {
// Has unexported fields
}
// Listen returns an in-memory listener with specified buffer size
func Listen(sz int) *Listener
// Accept blocks until Dial is called, then returns server half of connection
func (l *Listener) Accept() (net.Conn, error)
// Dial creates in-memory connection and returns client half
func (l *Listener) Dial() (net.Conn, error)
// DialContext creates connection with context support
func (l *Listener) DialContext(ctx context.Context) (net.Conn, error)
// Close stops the listener
func (l *Listener) Close() error
// Addr reports the address of the listener
func (l *Listener) Addr() net.Addrimport (
"context"
"net"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
)
const bufSize = 1024 * 1024

// setupTest starts a gRPC server serving myServiceImpl on an in-memory bufconn
// listener. It returns the server, a client connection to it, and a cleanup
// function that closes the connection and stops the server.
//
// cleanup waits for the serving goroutine to return, so that goroutine can
// never call t after the test has finished (which would panic).
func setupTest(t *testing.T) (*grpc.Server, *grpc.ClientConn, func()) {
	t.Helper()

	// Create in-memory listener
	lis := bufconn.Listen(bufSize)

	// Create and start server
	server := grpc.NewServer()
	pb.RegisterMyServiceServer(server, &myServiceImpl{})
	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		// Serve returns nil after Stop, or grpc.ErrServerStopped if Stop ran
		// before Serve started; neither indicates a test failure.
		if err := server.Serve(lis); err != nil && err != grpc.ErrServerStopped {
			t.Errorf("Server exited with error: %v", err)
		}
	}()

	// stopServer shuts the server down and waits for the serving goroutine.
	stopServer := func() {
		server.Stop()
		lis.Close()
		<-serveDone
	}

	// Create client connection
	conn, err := grpc.NewClient("bufnet",
		grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		stopServer() // don't leak the running server when aborting the test
		t.Fatalf("Failed to dial bufnet: %v", err)
	}

	// Return server, conn, and cleanup function
	cleanup := func() {
		conn.Close()
		stopServer()
	}
	return server, conn, cleanup
}
func TestMyService(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
// Test RPCs
resp, err := client.MyMethod(context.Background(), &pb.Request{
Name: "test",
})
if err != nil {
t.Fatalf("MyMethod failed: %v", err)
}
if resp.Message != "expected" {
t.Errorf("Expected 'expected', got %q", resp.Message)
}
}func TestMyService_TableDriven(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
tests := []struct {
name string
input *pb.Request
want *pb.Response
wantErr bool
}{
{
name: "valid request",
input: &pb.Request{Name: "test"},
want: &pb.Response{Message: "Hello, test"},
},
{
name: "empty name",
input: &pb.Request{Name: ""},
wantErr: true,
},
{
name: "long name",
input: &pb.Request{Name: strings.Repeat("a", 1000)},
want: &pb.Response{Message: "Hello, " + strings.Repeat("a", 1000)},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := client.MyMethod(context.Background(), tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("MyMethod() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !proto.Equal(resp, tt.want) {
t.Errorf("MyMethod() = %v, want %v", resp, tt.want)
}
})
}
}func TestServerStreaming(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
stream, err := client.ServerStreamMethod(context.Background(), &pb.Request{})
if err != nil {
t.Fatalf("ServerStreamMethod failed: %v", err)
}
var responses []*pb.Response
for {
resp, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
responses = append(responses, resp)
}
if len(responses) != 3 {
t.Errorf("Expected 3 responses, got %d", len(responses))
}
}func TestClientStreaming(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
stream, err := client.ClientStreamMethod(context.Background())
if err != nil {
t.Fatalf("ClientStreamMethod failed: %v", err)
}
requests := []*pb.Request{
{Name: "req1"},
{Name: "req2"},
{Name: "req3"},
}
for _, req := range requests {
if err := stream.Send(req); err != nil {
t.Fatalf("Send failed: %v", err)
}
}
resp, err := stream.CloseAndRecv()
if err != nil {
t.Fatalf("CloseAndRecv failed: %v", err)
}
if resp.Count != 3 {
t.Errorf("Expected count=3, got %d", resp.Count)
}
}func TestBidiStreaming(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
stream, err := client.BidiStreamMethod(context.Background())
if err != nil {
t.Fatalf("BidiStreamMethod failed: %v", err)
}
// Send requests
go func() {
requests := []*pb.Request{
{Name: "req1"},
{Name: "req2"},
{Name: "req3"},
}
for _, req := range requests {
if err := stream.Send(req); err != nil {
t.Errorf("Send failed: %v", err)
}
}
stream.CloseSend()
}()
// Receive responses
var responses []*pb.Response
for {
resp, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
t.Fatalf("Recv failed: %v", err)
}
responses = append(responses, resp)
}
if len(responses) != 3 {
t.Errorf("Expected 3 responses, got %d", len(responses))
}
}import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
func TestErrorHandling(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
// Test not found error
_, err := client.GetUser(context.Background(), &pb.GetUserRequest{
Id: "nonexistent",
})
if err == nil {
t.Fatal("Expected error, got nil")
}
st := status.Convert(err)
if st.Code() != codes.NotFound {
t.Errorf("Expected NotFound, got %v", st.Code())
}
if !strings.Contains(st.Message(), "nonexistent") {
t.Errorf("Error message should contain 'nonexistent': %v", st.Message())
}
}func TestTimeout(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
_, err := client.SlowMethod(ctx, &pb.Request{})
if err == nil {
t.Fatal("Expected timeout error")
}
st := status.Convert(err)
if st.Code() != codes.DeadlineExceeded {
t.Errorf("Expected DeadlineExceeded, got %v", st.Code())
}
}func TestCancellation(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
_, err := client.LongRunningMethod(ctx, &pb.Request{})
if err == nil {
t.Fatal("Expected cancellation error")
}
st := status.Convert(err)
if st.Code() != codes.Canceled {
t.Errorf("Expected Canceled, got %v", st.Code())
}
}func TestInterceptor(t *testing.T) {
var interceptorCalled bool
testInterceptor := func(
ctx context.Context,
req any,
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (any, error) {
interceptorCalled = true
return handler(ctx, req)
}
lis := bufconn.Listen(bufSize)
server := grpc.NewServer(grpc.UnaryInterceptor(testInterceptor))
pb.RegisterMyServiceServer(server, &myServiceImpl{})
defer server.Stop()
go server.Serve(lis)
conn, _ := grpc.NewClient("bufnet",
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
client := pb.NewMyServiceClient(conn)
client.MyMethod(context.Background(), &pb.Request{})
if !interceptorCalled {
t.Error("Interceptor was not called")
}
}import (
"google.golang.org/grpc/metadata"
)
func TestMetadata(t *testing.T) {
_, conn, cleanup := setupTest(t)
defer cleanup()
client := pb.NewMyServiceClient(conn)
// Send metadata
md := metadata.Pairs("key", "value")
ctx := metadata.NewOutgoingContext(context.Background(), md)
var header metadata.MD
_, err := client.MyMethod(ctx, &pb.Request{}, grpc.Header(&header))
if err != nil {
t.Fatalf("MyMethod failed: %v", err)
}
// Check response header
if val := header.Get("response-key"); len(val) == 0 {
t.Error("Expected response header not found")
}
}type mockServer struct {
pb.UnimplementedMyServiceServer
}
func (s *mockServer) MyMethod(ctx context.Context, req *pb.Request) (*pb.Response, error) {
// Mock implementation
return &pb.Response{Message: "mock response"}, nil
}
func TestWithMock(t *testing.T) {
lis := bufconn.Listen(bufSize)
server := grpc.NewServer()
pb.RegisterMyServiceServer(server, &mockServer{})
defer server.Stop()
go server.Serve(lis)
conn, _ := grpc.NewClient("bufnet",
grpc.WithContextDialer(func(ctx context.Context, s string) (net.Conn, error) {
return lis.DialContext(ctx)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()))
defer conn.Close()
client := pb.NewMyServiceClient(conn)
resp, err := client.MyMethod(context.Background(), &pb.Request{})
if err != nil {
t.Fatal(err)
}
if resp.Message != "mock response" {
t.Errorf("Expected 'mock response', got %q", resp.Message)
}
}t.Run()t.Parallel() when tests are independentpackage myservice_test
import (
"context"
"io"
"net"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
pb "mypackage/proto"
"mypackage/server"
)
const bufSize = 1024 * 1024

// setup starts a gRPC server running server.New() on an in-memory bufconn
// listener. It returns a MyService client connected to it and a cleanup
// function that closes the client connection and stops the server.
//
// cleanup waits for the serving goroutine to return, so that goroutine can
// never call t after the test has completed (which would panic).
func setup(t *testing.T) (pb.MyServiceClient, func()) {
	t.Helper()

	lis := bufconn.Listen(bufSize)
	s := grpc.NewServer()
	pb.RegisterMyServiceServer(s, server.New())

	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		// Serve returns nil after Stop, or grpc.ErrServerStopped if Stop ran
		// before Serve started; neither is worth reporting.
		if err := s.Serve(lis); err != nil && err != grpc.ErrServerStopped {
			t.Logf("Server exited: %v", err)
		}
	}()

	// stopServer shuts the server down and waits for the serving goroutine.
	stopServer := func() {
		s.Stop()
		lis.Close()
		<-serveDone
	}

	conn, err := grpc.NewClient("bufnet",
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		stopServer() // don't leak the running server when aborting the test
		t.Fatalf("Failed to dial: %v", err)
	}

	cleanup := func() {
		conn.Close()
		stopServer()
	}
	return pb.NewMyServiceClient(conn), cleanup
}
// TestMyService runs the unary, server-streaming, and error-handling checks as
// subtests that share a single client/server pair from setup.
func TestMyService(t *testing.T) {
	client, cleanup := setup(t)
	defer cleanup()
	ctx := context.Background()

	t.Run("UnaryMethod", func(t *testing.T) {
		resp, err := client.MyMethod(ctx, &pb.Request{Name: "test"})
		if err != nil {
			t.Fatal(err)
		}
		if got, want := resp.Message, "Hello, test"; got != want {
			t.Errorf("got %q, want %q", got, want)
		}
	})

	t.Run("StreamMethod", func(t *testing.T) {
		stream, err := client.StreamMethod(ctx, &pb.Request{})
		if err != nil {
			t.Fatal(err)
		}

		// Count messages until the server closes the stream with io.EOF.
		received := 0
	recv:
		for {
			_, err := stream.Recv()
			switch {
			case err == io.EOF:
				break recv
			case err != nil:
				t.Fatal(err)
			}
			received++
		}

		if received != 5 {
			t.Errorf("got %d messages, want 5", received)
		}
	})

	t.Run("ErrorHandling", func(t *testing.T) {
		_, err := client.MyMethod(ctx, &pb.Request{Name: ""})
		if err == nil {
			t.Fatal("expected error")
		}
		if code := status.Code(err); code != codes.InvalidArgument {
			t.Errorf("got %v, want InvalidArgument", code)
		}
	})
}