diff --git a/pkg/csi/plugins/nodeserver_test.go b/pkg/csi/plugins/nodeserver_test.go index 28e320217ab..caefbf47948 100644 --- a/pkg/csi/plugins/nodeserver_test.go +++ b/pkg/csi/plugins/nodeserver_test.go @@ -18,15 +18,20 @@ package plugins import ( "context" + "errors" "os" + "os/exec" "path/filepath" + "reflect" "time" + "github.com/agiledragon/gomonkey/v2" "github.com/container-storage-interface/spec/lib/go/csi" "github.com/fluid-cloudnative/fluid/api/v1alpha1" "github.com/fluid-cloudnative/fluid/pkg/common" "github.com/fluid-cloudnative/fluid/pkg/ddc/base" "github.com/fluid-cloudnative/fluid/pkg/utils" + "github.com/fluid-cloudnative/fluid/pkg/utils/cmdguard" csicommon "github.com/kubernetes-csi/drivers/pkg/csi-common" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -289,6 +294,48 @@ var _ = Describe("NodeServer", func() { }) }) + Context("when bind mounting succeeds", func() { + It("should return success after creating the target path", func() { + tempDir, err := os.MkdirTemp("", "node-publish-success-*") + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + }) + + fluidPath := filepath.Join(tempDir, "runtime", testName) + targetPath := filepath.Join(tempDir, "target") + + isMountedPatch := gomonkey.ApplyFunc(utils.IsMounted, func(absPath string) (bool, error) { + return false, os.ErrNotExist + }) + defer isMountedPatch.Reset() + + mountReadyPatch := gomonkey.ApplyFunc(utils.CheckMountReadyAndSubPathExist, func(fluidPath string, mountType string, subPath string) error { + return nil + }) + defer mountReadyPatch.Reset() + + commandPatch := gomonkey.ApplyFunc(cmdguard.Command, func(name string, arg ...string) (*exec.Cmd, error) { + return exec.Command("sh", "-c", "exit 0"), nil + }) + defer commandPatch.Reset() + + req := &csi.NodePublishVolumeRequest{ + VolumeId: testVolumeID, + TargetPath: targetPath, + VolumeContext: map[string]string{ + common.VolumeAttrFluidPath: fluidPath, + }, + } + + resp, err := ns.NodePublishVolume(context.Background(), req) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + Expect(targetPath).To(BeADirectory()) + }) + }) + Context("when skip check mount ready is set", func() { It("should skip mount ready check for mountPod mode", func() { // Use /tmp for test to avoid permission issues @@ -392,6 +439,67 @@ var _ = Describe("NodeServer", func() { Expect(os.IsNotExist(err)).To(BeTrue()) }) }) + + Context("when bind mount cleanup succeeds", func() { + It("should unmount until clean and remove the mount point", func() { + tempDir, err := os.MkdirTemp("", "node-unpublish-success-*") + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + }) + + targetPath := filepath.Join(tempDir, "mounted-target") + Expect(os.MkdirAll(targetPath, 0o750)).To(Succeed()) + + pathExistsPatch := gomonkey.ApplyFunc(utils.MountPathExists, func(path string) (bool, error) { + return true, nil + }) + defer pathExistsPatch.Reset() + + removeSymlinkPatch := gomonkey.ApplyFunc(utils.RemoveSymlink, func(path string) (bool, error) { + return false, nil + }) + defer removeSymlinkPatch.Reset() + + mounterType := reflect.TypeOf(&mount.Mounter{}) + isLikelyNotMountPointCalls := 0 + isLikelyNotMountPointPatch := gomonkey.ApplyMethod(mounterType, "IsLikelyNotMountPoint", func(_ *mount.Mounter, file string) (bool, error) { + isLikelyNotMountPointCalls++ + if isLikelyNotMountPointCalls == 1 { + return false, nil + } + return true, nil + }) + defer isLikelyNotMountPointPatch.Reset() + + unmountCalls := 0 + unmountPatch := gomonkey.ApplyMethod(mounterType, "Unmount", func(_ *mount.Mounter, target string) error { + unmountCalls++ + return nil + }) + defer unmountPatch.Reset() + + cleanupPatch := gomonkey.ApplyFunc(mount.CleanupMountPoint, func(path string, mounter mount.Interface, extensiveMountPointCheck bool) error { + return os.RemoveAll(path) + }) + defer cleanupPatch.Reset() + + req := &csi.NodeUnpublishVolumeRequest{ + VolumeId: testVolumeID, + TargetPath: targetPath, + } + + resp, err := ns.NodeUnpublishVolume(context.Background(), req) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + Expect(unmountCalls).To(Equal(1)) + Expect(isLikelyNotMountPointCalls).To(Equal(2)) + + _, statErr := os.Stat(targetPath) + Expect(os.IsNotExist(statErr)).To(BeTrue()) + }) + }) }) Describe("NodeStageVolume", func() { @@ -487,6 +595,46 @@ var _ = Describe("NodeServer", func() { }) }) + Context("when fuse label key is not provided", func() { + It("should fall back to runtime info to label the node", func() { + tempDir, err := os.MkdirTemp("", "node-stage-fallback-*") + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + }) + + runtimeInfo, err := base.BuildRuntimeInfo(testName, testNamespace, common.AlluxioRuntime) + Expect(err).NotTo(HaveOccurred()) + + runtimeInfoPatch := gomonkey.ApplyFunc(base.GetRuntimeInfo, func(client.Reader, string, string) (base.RuntimeInfoInterface, error) { + return runtimeInfo, nil + }) + defer runtimeInfoPatch.Reset() + + fluidPath := filepath.Join(tempDir, "runtime-fallback") + Expect(os.MkdirAll(fluidPath, 0o750)).To(Succeed()) + + req := &csi.NodeStageVolumeRequest{ + VolumeId: testVolumeID, + VolumeContext: map[string]string{ + common.VolumeAttrName: testName, + common.VolumeAttrNamespace: testNamespace, + common.VolumeAttrFluidPath: fluidPath, + }, + } + + resp, err := ns.NodeStageVolume(context.Background(), req) + + Expect(err).NotTo(HaveOccurred()) + Expect(resp).NotTo(BeNil()) + + updatedNode := &corev1.Node{} + err = mockClient.Get(context.Background(), types.NamespacedName{Name: testNode.Name}, updatedNode) + Expect(err).NotTo(HaveOccurred()) + Expect(updatedNode.Labels).To(HaveKeyWithValue(runtimeInfo.GetFuseLabelName(), "true")) + }) + }) + Context("when SessMgr is required", func() { It("should prepare SessMgr when work directory is specified", func() { workDir := "/tmp/sessmgr-work" @@ -873,20 +1021,58 @@ var _ = Describe("NodeServer", func() { }) Describe("isLikelyNeedUnmount", func() { - var mounter mount.Interface + var fakeMounter mount.Interface BeforeEach(func() { - mounter = mount.NewFakeMounter([]mount.MountPoint{}) + fakeMounter = mount.New("") }) Context("when path does not exist", func() { It("should return false without error", func() { - needUnmount, err := isLikelyNeedUnmount(mounter, "/non/existent/path") + patch := gomonkey.ApplyMethod(reflect.TypeOf(&mount.Mounter{}), "IsLikelyNotMountPoint", func(_ *mount.Mounter, file string) (bool, error) { + return true, os.ErrNotExist + }) + defer patch.Reset() + + needUnmount, err := isLikelyNeedUnmount(fakeMounter, "/non/existent/path") Expect(err).NotTo(HaveOccurred()) Expect(needUnmount).To(BeFalse()) }) }) + + Context("when mounter reports a mount point", func() { + It("should require unmount", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&mount.Mounter{}), "IsLikelyNotMountPoint", func(_ *mount.Mounter, file string) (bool, error) { + return false, nil + }) + defer patch.Reset() + + needUnmount, err := isLikelyNeedUnmount(fakeMounter, "/mounted/path") + + Expect(err).NotTo(HaveOccurred()) + Expect(needUnmount).To(BeTrue()) + }) + }) + + Context("when mounter returns an unexpected error", func() { + It("should return the error", func() { + patch := gomonkey.ApplyMethod(reflect.TypeOf(&mount.Mounter{}), "IsLikelyNotMountPoint", func(_ *mount.Mounter, file string) (bool, error) { + return false, errors.New("unexpected") + }) + defer patch.Reset() + + corruptedPatch := gomonkey.ApplyFunc(mount.IsCorruptedMnt, func(err error) bool { + return false + }) + defer corruptedPatch.Reset() + + needUnmount, err := isLikelyNeedUnmount(fakeMounter, "/error/path") + + Expect(err).To(MatchError("unexpected")) + Expect(needUnmount).To(BeFalse()) + }) + }) }) Describe("checkMountPathExists", func() { @@ -938,6 +1124,20 @@ var _ = Describe("NodeServer", func() { Expect(err).To(BeNil()) }) }) + + Context("when stat returns a non-not-exist error", func() { + It("should wrap the stat error", func() { + statPatch := gomonkey.ApplyFunc(os.Stat, func(name string) (os.FileInfo, error) { + return nil, errors.New("stat failed") + }) + defer statPatch.Reset() + + err := cleanUpBrokenMountPoint("/broken/path") + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to os.Stat(/broken/path)")) + }) + }) }) Describe("prepareSessMgr", func() { @@ -1029,6 +1229,75 @@ var _ = Describe("NodeServer", func() { }) }) + Context("when metadata file cannot be read", func() { + It("should return false", func() { + runtimeInfo, err := base.BuildRuntimeInfo(testName, testNamespace, common.AlluxioRuntime) + Expect(err).NotTo(HaveOccurred()) + + needUpdate := checkIfFuseNeedUpdate(runtimeInfo, "v2") + + Expect(needUpdate).To(BeFalse()) + }) + }) + + Context("when metadata generation differs from latest", func() { + It("should return true", func() { + tempDir, err := os.MkdirTemp("", "fuse-generation-*") + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + }) + + mountRoot := filepath.Join(tempDir, "mount-root") + Expect(os.Setenv(utils.MountRoot, mountRoot)).To(Succeed()) + DeferCleanup(func() { + Expect(os.Unsetenv(utils.MountRoot)).To(Succeed()) + }) + + runtimeInfo, err := base.BuildRuntimeInfo(testName, testNamespace, common.AlluxioRuntime) + Expect(err).NotTo(HaveOccurred()) + + fuseMetaDir := filepath.Join(mountRoot, common.AlluxioRuntime, testNamespace, testName, ".meta", "fuse") + Expect(os.MkdirAll(fuseMetaDir, 0o755)).To(Succeed()) + labelsFile := filepath.Join(fuseMetaDir, utils.MetaDataFuseLabelFileName) + labelsContent := []byte(common.LabelRuntimeFuseGeneration + "=\"v1\"\n") + Expect(os.WriteFile(labelsFile, labelsContent, 0o644)).To(Succeed()) + + needUpdate := checkIfFuseNeedUpdate(runtimeInfo, "v2") + + Expect(needUpdate).To(BeTrue()) + }) + }) + + Context("when metadata generation matches latest", func() { + It("should return false", func() { + tempDir, err := os.MkdirTemp("", "matching-fuse-generation-*") + Expect(err).NotTo(HaveOccurred()) + DeferCleanup(func() { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + }) + + mountRoot := filepath.Join(tempDir, "matching-mount-root") + Expect(os.Setenv(utils.MountRoot, mountRoot)).To(Succeed()) + DeferCleanup(func() { + Expect(os.Unsetenv(utils.MountRoot)).To(Succeed()) + }) + + runtimeInfo, err := base.BuildRuntimeInfo(testName, testNamespace, common.AlluxioRuntime) + Expect(err).NotTo(HaveOccurred()) + + fuseMetaDir := filepath.Join(mountRoot, common.AlluxioRuntime, testNamespace, testName, ".meta", "fuse") + Expect(os.MkdirAll(fuseMetaDir, 0o755)).To(Succeed()) + labelsFile := filepath.Join(fuseMetaDir, utils.MetaDataFuseLabelFileName) + labelsContent := []byte(common.LabelRuntimeFuseGeneration + "=\"v2\"\n") + Expect(os.WriteFile(labelsFile, labelsContent, 0o644)).To(Succeed()) + + needUpdate := checkIfFuseNeedUpdate(runtimeInfo, "v2") + + Expect(needUpdate).To(BeFalse()) + }) + }) + Context("when versions match", func() { It("should return false or handle appropriately", func() { // Create a simple mock runtime info @@ -1182,6 +1451,28 @@ var _ = Describe("NodeServer", func() { Expect(err).NotTo(HaveOccurred()) Expect(cleanFunc).NotTo(BeNil()) }) + + It("should remove the fuse label when the volume is no longer in use", func() { + updatedNode := testNode.DeepCopy() + updatedNode.Labels["test-fuse-label"] = "true" + Expect(mockClient.Update(context.Background(), updatedNode)).To(Succeed()) + + commandPatch := gomonkey.ApplyFunc(cmdguard.Command, func(name string, arg ...string) (*exec.Cmd, error) { + return exec.Command("sh", "-c", "exit 1"), nil + }) + defer commandPatch.Reset() + + cleanFunc, err := ns.getCleanFuseFunc(testVolumeID) + Expect(err).NotTo(HaveOccurred()) + Expect(cleanFunc).NotTo(BeNil()) + + Expect(cleanFunc()).To(Succeed()) + + nodeAfterCleanup := &corev1.Node{} + err = mockClient.Get(context.Background(), types.NamespacedName{Name: testNode.Name}, nodeAfterCleanup) + Expect(err).NotTo(HaveOccurred()) + Expect(nodeAfterCleanup.Labels).NotTo(HaveKey("test-fuse-label")) + }) }) Context("when clean policy is OnRuntimeDeleted", func() { @@ -1248,5 +1539,19 @@ var _ = Describe("NodeServer", func() { _ = err }) }) + + Context("when the command exits with status 1", func() { + It("should return not in use without error", func() { + patch := gomonkey.ApplyFunc(cmdguard.Command, func(name string, arg ...string) (*exec.Cmd, error) { + return exec.Command("sh", "-c", "exit 1"), nil + }) + defer patch.Reset() + + inUse, err := checkMountInUse("test-volume") + + Expect(err).NotTo(HaveOccurred()) + Expect(inUse).To(BeFalse()) + }) + }) }) }) diff --git a/pkg/csi/plugins/plugins_suite_test.go b/pkg/csi/plugins/plugins_suite_test.go index 77130aed5cc..716630699a6 100644 --- a/pkg/csi/plugins/plugins_suite_test.go +++ b/pkg/csi/plugins/plugins_suite_test.go @@ -1,4 +1,4 @@ -package plugins_test +package plugins import ( "testing" diff --git a/pkg/csi/plugins/register_test.go b/pkg/csi/plugins/register_test.go new file mode 100644 index 00000000000..fb7abf99da4 --- /dev/null +++ b/pkg/csi/plugins/register_test.go @@ -0,0 +1,215 @@ +/* +Copyright 2026 The Fluid Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package plugins + +import ( + "context" + stderrors "errors" + "net/http" + "os" + "path/filepath" + + "github.com/agiledragon/gomonkey/v2" + "github.com/fluid-cloudnative/fluid/pkg/csi/config" + "github.com/fluid-cloudnative/fluid/pkg/utils" + "github.com/go-logr/logr" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "k8s.io/apimachinery/pkg/api/meta" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/cache" + "sigs.k8s.io/controller-runtime/pkg/client" + ctrlconfig "sigs.k8s.io/controller-runtime/pkg/config" + "sigs.k8s.io/controller-runtime/pkg/healthz" + "sigs.k8s.io/controller-runtime/pkg/manager" + "sigs.k8s.io/controller-runtime/pkg/webhook" +) + +type registerTestManager struct { + manager.Manager + addErr error + addedRunner manager.Runnable + client client.Client + apiReader client.Reader +} + +func (m *registerTestManager) Add(runnable manager.Runnable) error { + m.addedRunner = runnable + return m.addErr +} + +func (m *registerTestManager) Elected() <-chan struct{} { + return nil +} + +func (m *registerTestManager) AddMetricsServerExtraHandler(path string, handler http.Handler) error { + return nil +} + +func (m *registerTestManager) AddHealthzCheck(name string, check healthz.Checker) error { + return nil +} + +func (m *registerTestManager) AddReadyzCheck(name string, check healthz.Checker) error { + return nil +} + +func (m *registerTestManager) Start(ctx context.Context) error { + return nil +} + +func (m *registerTestManager) GetConfig() *rest.Config { + return nil +} + +func (m *registerTestManager) GetScheme() *runtime.Scheme { + return nil +} + +func (m *registerTestManager) GetClient() client.Client { + return m.client +} + +func (m *registerTestManager) GetFieldIndexer() client.FieldIndexer { + return nil +} + +func (m *registerTestManager) GetCache() cache.Cache { + return nil +} + +func (m *registerTestManager) GetEventRecorderFor(name string) record.EventRecorder { + return nil +} + +func (m *registerTestManager) GetRESTMapper() meta.RESTMapper { + return nil +} + +func (m *registerTestManager) GetAPIReader() client.Reader { + return m.apiReader +} + +func (m *registerTestManager) GetWebhookServer() webhook.Server { + return nil +} + +func (m *registerTestManager) GetLogger() logr.Logger { + return logr.Discard() +} + +func (m *registerTestManager) GetControllerOptions() ctrlconfig.Controller { + return ctrlconfig.Controller{} +} + +var _ = Describe("Register", func() { + var tempDir string + + BeforeEach(func() { + var err error + tempDir, err = os.MkdirTemp("", "plugins-register-*") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + if tempDir != "" { + Expect(os.RemoveAll(tempDir)).To(Succeed()) + } + }) + + Describe("getNodeAuthorizedClientFromKubeletConfig", func() { + It("should return no client when the kubelet config file does not exist", func() { + clientset, err := getNodeAuthorizedClientFromKubeletConfig(filepath.Join(tempDir, "missing-kubelet.conf")) + + Expect(err).NotTo(HaveOccurred()) + Expect(clientset).To(BeNil()) + }) + + It("should return stat errors for invalid paths", func() { + statPatch := gomonkey.ApplyFunc(os.Stat, func(string) (os.FileInfo, error) { + return nil, stderrors.New("stat failed") + }) + defer statPatch.Reset() + + clientset, err := getNodeAuthorizedClientFromKubeletConfig(filepath.Join(tempDir, "kubelet.conf")) + + Expect(err).To(HaveOccurred()) + Expect(clientset).To(BeNil()) + Expect(err.Error()).To(ContainSubstring("fail to stat kubelet config file")) + Expect(err.Error()).To(ContainSubstring("stat failed")) + }) + }) + + Describe("Register", func() { + var ( + mgr *registerTestManager + ctx config.RunningContext + ) + + BeforeEach(func() { + mgr = ®isterTestManager{} + ctx = config.RunningContext{ + Config: config.Config{ + NodeId: "test-node", + Endpoint: "unix://" + filepath.Join(tempDir, "csi.sock"), + KubeletConfigPath: filepath.Join(tempDir, "missing-kubelet.conf"), + }, + VolumeLocks: utils.NewVolumeLocks(), + } + }) + + It("should add the constructed driver to the manager", func() { + err := Register(mgr, ctx) + + Expect(err).NotTo(HaveOccurred()) + Expect(mgr.addedRunner).To(BeAssignableToTypeOf(&driver{})) + + addedDriver := mgr.addedRunner.(*driver) + Expect(addedDriver.nodeId).To(Equal(ctx.NodeId)) + Expect(addedDriver.endpoint).To(Equal(ctx.Endpoint)) + Expect(addedDriver.nodeAuthorizedClient).To(BeNil()) + Expect(addedDriver.locks).To(BeIdenticalTo(ctx.VolumeLocks)) + }) + + It("should return manager add errors", func() { + mgr.addErr = stderrors.New("add failed") + + err := Register(mgr, ctx) + + Expect(err).To(MatchError("add failed")) + }) + + It("should return kubelet client initialization errors", func() { + invalidKubeletConfigPath := filepath.Join(tempDir, "kubelet.conf") + Expect(os.WriteFile(invalidKubeletConfigPath, []byte("not-a-kubeconfig"), 0o644)).To(Succeed()) + ctx.KubeletConfigPath = invalidKubeletConfigPath + + err := Register(mgr, ctx) + + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("fail to build kubelet config")) + }) + }) + + Describe("Enabled", func() { + It("should always enable the CSI plugin", func() { + Expect(Enabled()).To(BeTrue()) + }) + }) +})